From a1b8ecf667efa94a88840f9a77ecad4bf688600c Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 22 Sep 2023 02:36:38 -0700 Subject: [PATCH 01/15] Initial implementation of the `transparent` mode. --- Cargo.lock | 10 + Cargo.toml | 1 + Dockerfile | 15 +- Dockerfile.dev | 19 +- README.md | 4 + pgcat.toml | 6 +- src/client.rs | 700 ++++++++++++++++++++++++++++-------------- src/client_xact.rs | 581 +++++++++++++++++++++++++++++++++++ src/config.rs | 47 ++- src/lib.rs | 3 + src/messages.rs | 56 +++- src/pool.rs | 2 +- src/query_messages.rs | 456 +++++++++++++++++++++++++++ src/server.rs | 88 +++--- src/server_xact.rs | 544 ++++++++++++++++++++++++++++++++ 15 files changed, 2239 insertions(+), 293 deletions(-) create mode 100644 src/client_xact.rs create mode 100644 src/query_messages.rs create mode 100644 src/server_xact.rs diff --git a/Cargo.lock b/Cargo.lock index 929f1a81..56015e4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1035,6 +1035,7 @@ dependencies = [ "tracing", "tracing-subscriber", "trust-dns-resolver", + "uuid", "webpki-roots", ] @@ -1899,6 +1900,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "uuid" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 805a4c7a..3fd69c55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ itertools = "0.10" clap = { version = "4.3.1", features = ["derive", "env"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]} +uuid = { version = "1.4.1", features = ["v4"] } [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/Dockerfile b/Dockerfile index 
f2d58062..57c4ec11 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,19 @@ -FROM rust:1-slim-bookworm AS builder +FROM ubuntu:23.04 AS builder RUN apt-get update && \ - apt-get install -y build-essential + apt-get install -y build-essential curl + +# Get Rust +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y + +ENV PATH="/root/.cargo/bin:${PATH}" COPY . /app WORKDIR /app -RUN cargo build --release +RUN /root/.cargo/bin/cargo build --release -FROM debian:bookworm-slim -COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat +FROM ubuntu:23.04 +COPY --from=builder /app/target/release /usr/bin/ COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml WORKDIR /etc/pgcat ENV RUST_LOG=info diff --git a/Dockerfile.dev b/Dockerfile.dev index a4b8d0ed..1b3cd3c1 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -1,7 +1,16 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1 AS chef +FROM ubuntu:23.04 AS chef RUN apt-get update && \ - apt-get install -y build-essential + apt-get install -y build-essential curl + +# Get Rust +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y + +ENV PATH="/root/.cargo/bin:${PATH}" + +# We only pay the installation cost once, +# it will be cached from the second build onwards +RUN cargo install cargo-chef WORKDIR /app @@ -15,10 +24,10 @@ COPY --from=planner /app/recipe.json recipe.json RUN cargo chef cook --release --recipe-path recipe.json # Build application COPY . . 
-RUN cargo build +RUN cargo build --release -FROM debian:bookworm-slim -COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat +FROM ubuntu:23.04 +COPY --from=builder /app/target/release /usr/bin/ COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml WORKDIR /etc/pgcat ENV RUST_LOG=info diff --git a/README.md b/README.md index ae310cde..065d2c96 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal |-------------|------------|--------------| | Transaction pooling | **Stable** | Identical to PgBouncer with notable improvements for handling bad clients and abandoned transactions. | | Session pooling | **Stable** | Identical to PgBouncer. | +| Transparent pooling | **Stable** | A new pooling mechanism that enables transparent (distributed) transactions. | | Multi-threaded runtime | **Stable** | Using Tokio asynchronous runtime, the pooler takes advantage of multicore machines. | | Load balancing of read queries | **Stable** | Queries are automatically load balanced between replicas and the primary. | | Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. | @@ -145,6 +146,9 @@ In transaction mode, a client talks to one server for the duration of a single t This mode is enabled by default. +### Transparent mode +In transparent mode, a client talks to one or more servers for the duration of a single transaction; once it's over, the servers are returned to the pool. `SET SHARD` and `SET SHARDING KEY` statements **are** supported, but prepared statements, other `SET` statements and advisory locks **are not** supported. + ### Load balancing of read queries All queries are load balanced against the configured servers using either the random or least open connections algorithms. The most straightforward configuration example would be to put this pooler in front of several replicas and let it load balance all queries. 
diff --git a/pgcat.toml b/pgcat.toml index 772a1365..ebf7a4f7 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -143,7 +143,7 @@ result = [ # Pool mode (see PgBouncer docs for more). # `session` one server connection per connected client # `transaction` one server connection per client transaction -pool_mode = "transaction" +pool_mode = "transparent" # Load balancing mode # `random` selects the server at random @@ -227,7 +227,7 @@ connect_timeout = 3000 [pool.sharded_db.plugins] [pools.sharded_db.plugins.prewarmer] -enabled = true +enabled = false queries = [ "SELECT pg_prewarm('pgbench_accounts')", ] @@ -280,8 +280,6 @@ username = "sharding_user" # if `server_password` is not set. password = "sharding_user" -pool_mode = "transaction" - # PostgreSQL username used to connect to the server. # server_username = "another_user" diff --git a/src/client.rs b/src/client.rs index 4b281121..16ee067e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -14,8 +14,10 @@ use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_parameters_for_admin, handle_admin}; use crate::auth_passthrough::refetch_auth_hash; +use crate::client_xact::*; use crate::config::{ get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, + Role, }; use crate::constants::*; use crate::messages::*; @@ -23,6 +25,7 @@ use crate::plugins::PluginOutput; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; use crate::server::{Server, ServerParameters}; +use crate::server_xact::*; use crate::stats::{ClientStats, ServerStats}; use crate::tls::Tls; @@ -47,21 +50,21 @@ pub struct Client { /// We buffer the writes ourselves because we know the protocol /// better than a stock buffer. - write: T, + pub(crate) write: T, /// Internal buffer, where we place messages until we have to flush /// them to the backend. 
buffer: BytesMut, /// Address - addr: std::net::SocketAddr, + pub(crate) addr: std::net::SocketAddr, /// The client was started with the sole reason to cancel another running query. cancel_mode: bool, /// In transaction mode, the connection is released after each transaction. /// Session mode has slightly higher throughput per client, but lower capacity. - transaction_mode: bool, + client_pool_mode: PoolMode, /// For query cancellation, the client is given a random process ID and secret on startup. process_id: i32, @@ -76,7 +79,7 @@ pub struct Client { parameters: HashMap, /// Statistics related to this client - stats: Arc, + pub(crate) stats: Arc, /// Clients want to talk to admin database. admin: bool, @@ -87,6 +90,9 @@ pub struct Client { /// Last server process stats we talked to. last_server_stats: Option>, + /// Last server key we talked to. + pub(crate) last_server_key: Option<(usize, Option)>, + /// Connected to server connected_to_server: bool, @@ -97,13 +103,15 @@ pub struct Client { username: String, /// Server startup and session parameters that we're going to track - server_parameters: ServerParameters, + pub(crate) server_parameters: ServerParameters, /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, /// Prepared statements prepared_statements: HashMap, + + pub(crate) xact_info: TransactionMetaData, } /// Client entrypoint. @@ -518,7 +526,7 @@ where }; // Authenticate admin user. - let (transaction_mode, mut server_parameters) = if admin { + let (client_pool_mode, mut server_parameters) = if admin { let config = get_config(); // Compare server and client hashes. @@ -537,7 +545,7 @@ where return Err(error); } - (false, generate_server_parameters_for_admin()) + (PoolMode::Session, generate_server_parameters_for_admin()) } // Authenticate normal user. 
else { @@ -547,7 +555,7 @@ where error_response( &mut write, &format!( - "No pool configured for database: {:?}, user: {:?}", + "No pool configured for database: {:?}, user: {:?} (in startup)", pool_name, username ), ) @@ -649,7 +657,7 @@ where } } - let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; + let client_pool_mode = pool.settings.pool_mode; // If the pool hasn't been validated yet, // connect to the servers and figure out what's what. @@ -670,7 +678,7 @@ where } } - (transaction_mode, pool.server_parameters()) + (client_pool_mode, pool.server_parameters()) }; // Update the parameters to merge what the application sent and what's originally on the server @@ -698,7 +706,7 @@ where addr, buffer: BytesMut::with_capacity(8196), cancel_mode: false, - transaction_mode, + client_pool_mode, process_id, secret_key, client_server_map, @@ -707,12 +715,14 @@ where admin, last_address_id: None, last_server_stats: None, + last_server_key: None, pool_name: pool_name.clone(), username: username.clone(), server_parameters, shutdown, connected_to_server: false, prepared_statements: HashMap::new(), + xact_info: Default::default(), }) } @@ -733,7 +743,7 @@ where addr, buffer: BytesMut::with_capacity(8196), cancel_mode: true, - transaction_mode: false, + client_pool_mode: PoolMode::Session, process_id, secret_key, client_server_map, @@ -742,12 +752,14 @@ where admin: false, last_address_id: None, last_server_stats: None, + last_server_key: None, pool_name: String::from("undefined"), username: String::from("undefined"), server_parameters: ServerParameters::new(), shutdown, connected_to_server: false, prepared_statements: HashMap::new(), + xact_info: Default::default(), }) } @@ -804,8 +816,8 @@ where // or issue commands for our sharding and server selection protocol. 
loop { trace!( - "Client idle, waiting for message, transaction mode: {}", - self.transaction_mode + "Client idle, waiting for message, pool mode: {:?}", + self.client_pool_mode ); // Should we rewrite prepared statements and bind messages? @@ -1000,23 +1012,107 @@ where query_router.update_pool_settings(pool.settings.clone()); - let current_shard = query_router.shard(); + let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { + 0 => tokio::time::Duration::MAX, + timeout => tokio::time::Duration::from_millis(timeout), + }; + + let mut all_conns: HashMap< + (usize, Option), + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + > = HashMap::new(); + let mut initial_message = Some(message); + + // Reset transaction state, as we are entering a new transaction loop. + reset_client_xact(self); + + let mut should_continue_in_outer_loop = false; + + // Transaction loop. Multiple queries can be issued by the client here. + // The connection belongs to the client until the transaction is over, + // or until the client disconnects if we are in session mode. + // + // If the client is in session mode, no more custom protocol + // commands will be accepted. + loop { + let mut message = match initial_message { + None => { + trace!("Waiting for message inside transaction or in session mode"); + + // This is not an initial message so discard the initial_parsed_ast + initial_parsed_ast.take(); + + match tokio::time::timeout( + idle_client_timeout_duration, + read_message(&mut self.read), + ) + .await + { + Ok(Ok(message)) => message, + Ok(Err(err)) => { + // Client disconnected inside a transaction. + // Clean up the server and re-use it. 
+ self.stats.disconnect(); + for (_, mut conn) in all_conns { + let server = &mut *conn.0; + server.checkin_cleanup().await?; + } + + return Err(err); + } + Err(_) => { + // Client idle in transaction timeout + error_response_with_state( + &mut self.write, + "idle transaction timeout", + self.xact_info.state(), + ) + .await?; + error!( + "Client idle in transaction timeout: \ + {{ \ + pool_name: {}, \ + username: {}, \ + shard: {:?}, \ + role: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role() + ); + + break; + } + } + } + + Some(message) => { + initial_message = None; + message + } + }; + + assign_client_transaction_state(self, &all_conns); - // Handle all custom protocol commands, if any. - match query_router.try_execute_command(&message) { - // Normal query, not a custom command. - None => (), + if all_conns.is_empty() || self.is_transparent_mode() { + let current_shard = query_router.shard(); - // SET SHARD TO - Some((Command::SetShard, _)) => { - match query_router.shard() { + // Handle all custom protocol commands, if any. + match query_router.try_execute_command(&message) { + // Normal query, not a custom command. None => (), - Some(selected_shard) => { - if selected_shard >= pool.shards() { - // Bad shard number, send error message to client. - query_router.set_shard(current_shard); - error_response( + // SET SHARD TO + Some((Command::SetShard, _)) => { + match query_router.shard() { + None => (), + Some(selected_shard) => { + if selected_shard >= pool.shards() { + // Bad shard number, send error message to client. 
+ query_router.set_shard(current_shard); + + error_response_with_state( &mut self.write, &format!( "shard {} is not configured {}, staying on shard {:?} (shard numbers start at 0)", @@ -1024,146 +1120,241 @@ where pool.shards(), current_shard, ), + self.xact_info.state(), ) .await?; + } else { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET SHARD", + self.xact_info.state(), + ) + .await?; + } + } + } + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; } else { - custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; + continue; } } - } - continue; - } - // SET PRIMARY READS TO - Some((Command::SetPrimaryReads, _)) => { - custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?; - continue; - } - - // SET SHARDING KEY TO - Some((Command::SetShardingKey, _)) => { - custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?; - continue; - } + // SET PRIMARY READS TO + Some((Command::SetPrimaryReads, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET PRIMARY READS", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; + } else { + continue; + } + } - // SET SERVER ROLE TO - Some((Command::SetServerRole, _)) => { - custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; - continue; - } + // SET SHARDING KEY TO + Some((Command::SetShardingKey, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET SHARDING KEY", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; + } else { + continue; + } + } - // SHOW SERVER ROLE - Some((Command::ShowServerRole, value)) => { - show_response(&mut self.write, "server role", &value).await?; - continue; - } + // SET SERVER ROLE TO + Some((Command::SetServerRole, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET 
SERVER ROLE", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; + } else { + continue; + } + } - // SHOW SHARD - Some((Command::ShowShard, value)) => { - show_response(&mut self.write, "shard", &value).await?; - continue; - } + // SHOW SERVER ROLE + Some((Command::ShowServerRole, value)) => { + show_response(&mut self.write, "server role", &value).await?; + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; + } else { + continue; + } + } - // SHOW PRIMARY READS - Some((Command::ShowPrimaryReads, value)) => { - show_response(&mut self.write, "primary reads", &value).await?; - continue; - } - }; + // SHOW SHARD + Some((Command::ShowShard, value)) => { + show_response(&mut self.write, "shard", &value).await?; + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; + } else { + continue; + } + } - debug!("Waiting for connection from pool"); - if !self.admin { - self.stats.waiting(); - } + // SHOW PRIMARY READS + Some((Command::ShowPrimaryReads, value)) => { + show_response(&mut self.write, "primary reads", &value).await?; + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; + } else { + continue; + } + } + }; - // Grab a server from the pool. 
- let connection = match pool - .get(query_router.shard(), query_router.role(), &self.stats) - .await - { - Ok(conn) => { - debug!("Got connection from pool"); - conn - } - Err(err) => { - // Client is attempting to get results from the server, - // but we were unable to grab a connection from the pool - // We'll send back an error message and clean the extended - // protocol buffer - self.stats.idle(); - - if message[0] as char == 'S' { - error!("Got Sync message but failed to get a connection from the pool"); - self.buffer.clear(); + debug!("Waiting for connection from pool"); + if !self.admin { + self.stats.waiting(); } - error_response( - &mut self.write, - format!("could not get connection from the pool - {}", err).as_str(), - ) - .await?; - - error!( - "Could not get connection from pool: \ - {{ \ - pool_name: {:?}, \ - username: {:?}, \ - shard: {:?}, \ - role: \"{:?}\", \ - error: \"{:?}\" \ - }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role(), - err - ); - - continue; - } - }; - - let mut reference = connection.0; - let address = connection.1; - let server = &mut *reference; + let server_key = (query_router.shard().unwrap_or(0), query_router.role()); + let mut conn_opt = all_conns.get_mut(&server_key); + if conn_opt.is_none() { + // Grab a server from the pool. + let connection = match pool + .get(query_router.shard(), query_router.role(), &self.stats) + .await + { + Ok(conn) => { + debug!("Got connection from pool"); + conn + } + Err(err) => { + // Client is attempting to get results from the server, + // but we were unable to grab a connection from the pool + // We'll send back an error message and clean the extended + // protocol buffer + self.stats.idle(); - // Server is assigned to the client in case the client wants to - // cancel a query later. 
- server.claim(self.process_id, self.secret_key); - self.connected_to_server = true; + if message[0] as char == 'S' { + error!("Got Sync message but failed to get a connection from the pool"); + self.buffer.clear(); + } - // Update statistics - self.stats.active(); + error_response_with_state( + &mut self.write, + format!("could not get connection from the pool - {}", err) + .as_str(), + self.xact_info.state(), + ) + .await?; - self.last_address_id = Some(address.id); - self.last_server_stats = Some(server.stats()); + error!( + "Could not get connection from pool: \ + {{ \ + pool_name: {:?}, \ + username: {:?}, \ + shard: {:?}, \ + role: \"{:?}\", \ + error: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role(), + err + ); - debug!( - "Client {:?} talking to server {:?}", - self.addr, - server.address() - ); + if self.is_in_idle_transaction() { + should_continue_in_outer_loop = true; + break; + } else { + continue; + } + } + }; - server.sync_parameters(&self.server_parameters).await?; + // Before inserting this new connection, if we had only a single connection + // before, then it means that we have started a distributed transaction. + // At this point, we need to acquire the snapshot from the first server and + // use that snapshot for all the other servers. + if all_conns.len() == 1 { + let (first_server_key, first_conn) = + all_conns.iter_mut().next().unwrap(); + let first_server = &mut *first_conn.0; + + if !acquire_gid_and_snapshot(self, first_server_key, first_server) + .await? 
+ { + break; + } + } - let mut initial_message = Some(message); + all_conns.insert(server_key, connection); + let is_distributed_xact = all_conns.len() > 1; + conn_opt = if is_distributed_xact { + let conn_opt = all_conns.get_mut(&server_key); + let conn = conn_opt.unwrap(); + let server = &mut *conn.0; + let address = &conn.1; + + debug!( + "Sending implicit BEGIN statement to server {} (in transparent mode with distributed transaction)", + address + ); + if !begin_distributed_xact(self, &server_key, server).await? { + break; + } + Some(conn) + } else { + all_conns.get_mut(&server_key) + } + } + let conn = conn_opt.unwrap(); + let server = &mut *conn.0; + let address = &conn.1; + + // Server is assigned to the client in case the client wants to + // cancel a query later. + server.claim(self.process_id, self.secret_key); + self.connected_to_server = true; + + // Update statistics + self.stats.active(); + + self.last_address_id = Some(address.id); + self.last_server_stats = Some(server.stats()); + self.last_server_key = Some(server_key); + + debug!( + "Client {:?} talking to server {:?}", + self.addr, + server.address() + ); - let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { - 0 => tokio::time::Duration::MAX, - timeout => tokio::time::Duration::from_millis(timeout), - }; + server.sync_parameters(&self.server_parameters).await?; + } - // Transaction loop. Multiple queries can be issued by the client here. - // The connection belongs to the client until the transaction is over, - // or until the client disconnects if we are in session mode. - // - // If the client is in session mode, no more custom protocol - // commands will be accepted. 
- loop { + let is_distributed_xact = all_conns.len() > 1; + let server_key = (query_router.shard().unwrap_or(0), query_router.role()); + let conn = all_conns.get_mut(&server_key).unwrap(); + let server = &mut *conn.0; + let address = &conn.1; // Only check if we should rewrite prepared statements // in session mode. In transaction mode, we check at the beginning of // each transaction. - if !self.transaction_mode { + if !self.is_transaction_mode() { prepared_statements_enabled = get_prepared_statements(); } @@ -1197,7 +1388,7 @@ where Ok(_) => (), Err(err) => { pool.ban( - &address, + address, BanReason::MessageSendFailed, Some(&self.stats), ); @@ -1211,56 +1402,6 @@ where prepared_statement = None; } - let mut message = match initial_message { - None => { - trace!("Waiting for message inside transaction or in session mode"); - - // This is not an initial message so discard the initial_parsed_ast - initial_parsed_ast.take(); - - match tokio::time::timeout( - idle_client_timeout_duration, - read_message(&mut self.read), - ) - .await - { - Ok(Ok(message)) => message, - Ok(Err(err)) => { - // Client disconnected inside a transaction. - // Clean up the server and re-use it. - self.stats.disconnect(); - server.checkin_cleanup().await?; - - return Err(err); - } - Err(_) => { - // Client idle in transaction timeout - error_response(&mut self.write, "idle transaction timeout").await?; - error!( - "Client idle in transaction timeout: \ - {{ \ - pool_name: {}, \ - username: {}, \ - shard: {:?}, \ - role: \"{:?}\" \ - }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role() - ); - - break; - } - } - } - - Some(message) => { - initial_message = None; - message - } - }; - // The message will be forwarded to the server intact. We still would like to // parse it below to figure out what to do with it. 
@@ -1268,33 +1409,31 @@ where // This reads the first byte without advancing the internal pointer and mutating the bytes let code = *message.get(0).unwrap() as char; - trace!("Message: {}", code); + trace!("client Message: {}", code); match code { // Query 'Q' => { - if query_router.query_parser_enabled() { - // We don't want to parse again if we already parsed it as the initial message - let ast = match initial_parsed_ast { - Some(_) => Some(initial_parsed_ast.take().unwrap()), - None => match query_router.parse(&message) { - Ok(ast) => Some(ast), - Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); - None - } - }, - }; - - if let Some(ast) = ast { + let mut ast = None; + if is_distributed_xact || query_router.query_parser_enabled() { + ast = parse_ast( + &mut initial_parsed_ast, + &query_router, + &message, + &client_identifier, + ); + + ast = if let Some(ast) = ast { let plugin_result = query_router.execute_plugins(&ast).await; match plugin_result { Ok(PluginOutput::Deny(error)) => { - error_response(&mut self.write, &error).await?; + error_response_with_state( + &mut self.write, + &error, + self.xact_info.state(), + ) + .await?; continue; } @@ -1305,30 +1444,64 @@ where _ => (), }; + + if is_distributed_xact { + if set_commit_or_abort_statement(self, &ast) { + break; + } + } + + Some(ast) + } else { + None } } - debug!("Sending query to server"); + debug!("Sending query to server (in Query mode)"); + + let server_was_in_transaction = server.in_transaction(); self.send_and_receive_loop( code, Some(&message), server, - &address, + address, &pool, &self.stats.clone(), ) .await?; - if !server.in_transaction() { + if server.in_transaction() { + // If the server was not in transaction and now it is, we need to store the + // begin statement. The begin statement is used if/when contacting another + // server in the same transaction. 
+ if !server_was_in_transaction { + if ast.is_none() { + // We don't want to parse again if we already parsed it as the initial message + ast = parse_ast( + &mut initial_parsed_ast, + &query_router, + &message, + &client_identifier, + ); + } + assert!(ast.is_some()); + let ast_vec = ast.unwrap(); + assert_eq!(ast_vec.len(), 1); + + initialize_xact_info(self, server, &ast_vec[0]); + } + } else { // Report transaction executed statistics. self.stats.transaction(); server .stats() .transaction(&self.server_parameters.get_application_name()); - // Release server back to the pool if we are in transaction mode. + // Release server back to the pool if we are in transaction or transparent modes. // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode && !server.in_copy_mode() { + if (self.is_transaction_mode() || self.is_transparent_mode()) + && !server.in_copy_mode() + { self.stats.idle(); break; @@ -1422,11 +1595,16 @@ where // Sync // Frontend (client) is asking for the query result now. 'S' => { - debug!("Sending query to server"); + debug!("Sending query to server (in Sync mode)"); match plugin_output { Some(PluginOutput::Deny(error)) => { - error_response(&mut self.write, &error).await?; + error_response_with_state( + &mut self.write, + &error, + self.xact_info.state(), + ) + .await?; plugin_output = None; self.buffer.clear(); continue; @@ -1464,7 +1642,7 @@ where code, None, server, - &address, + address, &pool, &self.stats.clone(), ) @@ -1478,9 +1656,11 @@ where .stats() .transaction(&self.server_parameters.get_application_name()); - // Release server back to the pool if we are in transaction mode. + // Release server back to the pool if we are in transaction or transparent modes. // If we are in session mode, we keep the server until the client disconnects. 
- if self.transaction_mode && !server.in_copy_mode() { + if (self.is_transaction_mode() || self.is_transparent_mode()) + && !server.in_copy_mode() + { break; } } @@ -1493,7 +1673,7 @@ where // Want to limit buffer size if self.buffer.len() > 8196 { // Forward the data to the server, - self.send_server_message(server, &self.buffer, &address, &pool) + self.send_server_message(server, &self.buffer, address, &pool) .await?; self.buffer.clear(); } @@ -1505,14 +1685,14 @@ where // We may already have some copy data in the buffer, add this message to buffer self.buffer.put(&message[..]); - self.send_server_message(server, &self.buffer, &address, &pool) + self.send_server_message(server, &self.buffer, address, &pool) .await?; // Clear the buffer self.buffer.clear(); let response = self - .receive_server_message(server, &address, &pool, &self.stats.clone()) + .receive_server_message(server, address, &pool, &self.stats.clone()) .await?; match write_all_flush(&mut self.write, &response).await { @@ -1529,9 +1709,9 @@ where .stats() .transaction(self.server_parameters.get_application_name()); - // Release server back to the pool if we are in transaction mode. + // Release server back to the pool if we are in transaction or transparent modes. // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode { + if self.is_transaction_mode() || self.is_transparent_mode() { break; } } @@ -1545,16 +1725,34 @@ where } } - // The server is no longer bound to us, we can't cancel it's queries anymore. - debug!("Releasing server back into the pool"); + if should_continue_in_outer_loop { + continue; + } + + distributed_commit_or_abort(self, &mut all_conns).await?; + + // Reset transaction state for safety reasons. Even if this state will be reset before + // the next transaction, this dirty state could be seen in-between here and there. 
+ reset_client_xact(self); - server.checkin_cleanup().await?; + debug!("Releasing servers back into the pool"); - if prepared_statements_enabled { - server.maintain_cache().await?; + for (_, mut conn) in all_conns { + let server = &mut *conn.0; + let address = &conn.1; + + // The server is no longer bound to us, we can't cancel it's queries anymore. + debug!("Releasing server back into the pool: {}", address); + + server.checkin_cleanup().await?; + + if prepared_statements_enabled { + server.maintain_cache().await?; + } + + server.stats().idle(); } - server.stats().idle(); self.connected_to_server = false; self.release(); @@ -1571,7 +1769,7 @@ where error_response( &mut self.write, &format!( - "No pool configured for database: {}, user: {}", + "No pool configured for database: {}, user: {} (in get_pool)", self.pool_name, self.username ), ) @@ -1802,6 +2000,40 @@ where } } } + + pub fn is_transaction_mode(&self) -> bool { + self.client_pool_mode == PoolMode::Transaction + } + + pub fn is_transparent_mode(&self) -> bool { + self.client_pool_mode == PoolMode::Transparent + } + + pub fn is_in_idle_transaction(&self) -> bool { + self.xact_info.is_idle() + } +} + +fn parse_ast( + initial_parsed_ast: &mut Option>, + query_router: &QueryRouter, + message: &BytesMut, + client_identifier: &ClientIdentifier, +) -> Option> { + // We don't want to parse again if we already parsed it as the initial message + match *initial_parsed_ast { + Some(_) => Some(initial_parsed_ast.take().unwrap()), + None => match query_router.parse(message) { + Ok(ast) => Some(ast), + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + None + } + }, + } } impl Drop for Client { diff --git a/src/client_xact.rs b/src/client_xact.rs new file mode 100644 index 00000000..d64c8a0a --- /dev/null +++ b/src/client_xact.rs @@ -0,0 +1,581 @@ +use crate::client::Client; +use crate::errors::Error; +use crate::query_messages::{ErrorInfo, ErrorResponse, Message}; +use 
bytes::BytesMut; +/// Handle clients by pretending to be a PostgreSQL server. +use chrono::NaiveDateTime; +use futures::future::join_all; +use itertools::Either; +use log::{debug, warn}; +use sqlparser::ast::{Statement, TransactionAccessMode, TransactionMode}; +use std::collections::HashMap; +use uuid::Uuid; + +use crate::config::{Address, Role}; +use crate::messages::*; +use crate::server::Server; +use crate::server_xact::*; + +/// DistributedPrepareResult is an accumulator for the results of the 'PREPARE TRANSACTION's. +#[derive(Debug, Default)] +struct DistributedPrepareResult { + max_prepare_timestamp: NaiveDateTime, +} + +impl DistributedPrepareResult { + /// Returns the maximum prepare timestamp of all servers. + pub fn get_max_prepare_timestamp(&self) -> NaiveDateTime { + self.max_prepare_timestamp + } + + /// Accumulates the results of a 'PREPARE TRANSACTION'. + /// The result is true if the server has a 'PREPARE TRANSACTION' timestamp. + pub fn accumulate(&mut self, server: &Server) -> bool { + let prep_timestamp = server.transaction_metadata().get_prepared_timestamp(); + if prep_timestamp.is_none() { + false + } else { + self.max_prepare_timestamp = + std::cmp::max(self.max_prepare_timestamp, prep_timestamp.unwrap()); + true + } + } +} + +/// This function starts a distributed transaction by sending a BEGIN statement to the first server. +/// It is called on the first server, as soon as client wants to interact with another server, +/// which hints that the client wants to start a distributed transaction. +pub async fn begin_distributed_xact( + clnt: &mut Client, + server_key: &(usize, Option), + server: &mut Server, +) -> Result +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let begin_stmt = clnt.xact_info.get_begin_statement(); + assert!(begin_stmt.is_some()); + if let Some(err) = query_server(clnt, server, &begin_stmt.unwrap().to_string()).await? 
{ + error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; + return Ok(false); + } + + if clnt.xact_info.params.is_repeatable_read_or_higher() { + // If we are in a repeatable read or serializable transaction, we need to use the + // snapshot we acquired from the first server. + assert!(clnt.xact_info.get_snapshot().is_some()); + let snapshot = clnt.xact_info.get_snapshot().unwrap(); + + debug!( + "Assigning snapshot ('{}') to server {}", + snapshot, + server.address(), + ); + + let snapshot_res = assign_xact_snapshot(server, &snapshot).await?; + set_post_query_state(clnt, server); + + if let Some(err) = snapshot_res { + error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; + return Ok(false); + }; + } + + // If we are in a distributed transaction, we need to assign a GID to the transaction. + assert!(clnt.xact_info.get_xact_gid().is_some()); + let gid = clnt.xact_info.get_xact_gid().unwrap(); + + debug!("Assigning GID ('{}') to server {}", gid, server.address(),); + + let gid_res = assign_xact_gid(server, &gen_server_specific_gid(server_key, &gid)).await?; + set_post_query_state(clnt, server); + if let Some(err) = gid_res { + error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; + return Ok(false); + } + + Ok(true) +} + +/// This function generates a GID for the current transaction and sends it to the server. +/// Also, if the transaction is repeatable read or higher, it acquires a snapshot from the server.
+pub async fn acquire_gid_and_snapshot( + clnt: &mut Client, + server_key: &(usize, Option), + server: &mut Server, +) -> Result +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + assert!(clnt.xact_info.get_xact_gid().is_none()); + let gid = generate_xact_gid(clnt); + + debug!( + "Acquiring GID ('{}') and snapshot from server {}", + gid, + server.address(), + ); + + // If we are in a distributed transaction, we need to assign a GID to the transaction. + let gid_res = assign_xact_gid(server, &gen_server_specific_gid(server_key, &gid)).await?; + set_post_query_state(clnt, server); + if let Some(err) = gid_res { + error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; + return Ok(false); + } + clnt.xact_info.set_xact_gid(Some(gid)); + + if clnt.xact_info.params.is_repeatable_read_or_higher() { + // If we are in a repeatable read or serializable transaction, we need to acquire a + // snapshot from the server. + let snapshot_res = acquire_xact_snapshot(server).await?; + set_post_query_state(clnt, server); + + match snapshot_res { + Either::Left(snapshot) => { + debug!( + "Got first server snapshot: {} (on {})", + snapshot, + server.address() + ); + clnt.xact_info.set_snapshot(Some(snapshot)); + } + Either::Right(err) => { + error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; + return Ok(false); + } + } + } + Ok(true) +} + +/// Generates a random GID (i.e., Global transaction ID) for a transaction. +fn generate_xact_gid(clnt: &Client) -> String { + format!( + "txn_{}_{}", + clnt.addr.to_string(), + Uuid::new_v4().to_string() + ) +} + +/// Generates a server-specific GID for a transaction. We need this, because it's possible that +/// multiple servers might actually be the same server (which commonly happens in testing). 
+fn gen_server_specific_gid(server_key: &(usize, Option), gid: &str) -> String { + format!("{}_{}", server_key.0, gid) +} + +/// Assigns the transaction state based on the state of all servers. +pub fn assign_client_transaction_state( + clnt: &mut Client, + all_conns: &HashMap< + (usize, Option), + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, +) { + clnt.xact_info.set_state(if all_conns.is_empty() { + // if there's no server, we're in idle mode. + TransactionState::Idle + } else { + if is_any_server_in_failed_xact(all_conns) { + // if any server is in failed transaction, we're in failed transaction. + TransactionState::InFailedTransaction + } else { + // if we have at least one server and it is in a transaction, we're in a transaction. + TransactionState::InTransaction + } + }); +} + +/// Returns true if any server is in a failed transaction. +fn is_any_server_in_failed_xact( + all_conns: &HashMap< + (usize, Option), + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, +) -> bool { + all_conns + .iter() + .any(|(_, conn)| in_failed_transaction(&*conn.0)) +} + +/// This function initializes the transaction parameters based on the first BEGIN statement. +pub fn initialize_xact_info( + clnt: &mut Client, + server: &mut Server, + begin_stmt: &Statement, +) { + if let Statement::StartTransaction { modes } = begin_stmt { + // This is the first BEGIN statement. We need + clnt.xact_info.set_begin_statement(Some(begin_stmt.clone())); + + // Initialize transaction parameters using the server's default. 
+ clnt.xact_info.params = server_default_transaction_parameters(server); + for mode in modes { + match mode { + TransactionMode::AccessMode(access_mode) => { + clnt.xact_info.params.set_read_only(match access_mode { + TransactionAccessMode::ReadOnly => true, + TransactionAccessMode::ReadWrite => false, + }); + } + TransactionMode::IsolationLevel(isolation_level) => { + clnt.xact_info + .params + .set_isolation_level(isolation_level.clone()); + } + } + } + debug!( + "Transaction paramaters after the first BEGIN statement: {:?}", + clnt.xact_info.params + ); + + // Set the transaction parameters on the first server. + server.transaction_metadata_mut().params = clnt.xact_info.params.clone(); + } else { + // If we were not in a transaction and the first statement is + // not a BEGIN, then it's an irrecoverable error. + assert!(false); + } +} + +/// This function performs a distributed abort/commit if necessary, and also resets the transaction +/// state. This is supposed to be called before exiting the transaction loop. At that point, if +/// either an abort or commit statement is set, we need to perform a distributed abort/commit. This +/// is based on the logic that an abort or commit statement is only set if we are in a distributed +/// transaction and we observe a commit or abort statement sent to the server. That is where we exit +/// the transaction loop and expect this function to take over and abort/commit the transaction.
+pub async fn distributed_commit_or_abort( + clnt: &mut Client, + all_conns: &mut HashMap< + (usize, Option), + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, +) -> Result<(), Error> +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let dist_commit = clnt.xact_info.get_commit_statement(); + let dist_abort = clnt.xact_info.get_abort_statement(); + Ok(if dist_commit.is_some() || dist_abort.is_some() { + // if either a commit or abort statement is set, we should be in a distributed transaction. + assert!(all_conns.len() > 0); + + let is_chained = should_be_chained(dist_commit, dist_abort); + let dist_commit = dist_commit.map(|stmt| stmt.to_string()); + let mut dist_abort = dist_abort.map(|stmt| stmt.to_string()); + + // Report transaction executed statistics. + clnt.stats.transaction(); + + let mut is_distributed_commit_failed = false; + // We are in distributed transaction mode, and need to commit or abort on all servers. + if let Some(commit_stmt) = dist_commit { + // If two-phase commit was successful, we can send the COMMIT message to the client. + // Otherwise, we need to ROLLBACK on all servers. + if let Some(err) = distributed_commit(clnt, all_conns).await? { + error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; + + // Currently, if a distributed commit fails, we send a ROLLBACK to all servers. + // However, this is different from how Postgres handles it. Postgres sends an + // error response to the client, and then does not accept any more queries from + // the client until the client explicitly sends a ROLLBACK. 
+ dist_abort = Some("ROLLBACK".to_string()); + is_distributed_commit_failed = true; + } else { + custom_protocol_response_ok_with_state( + &mut clnt.write, + &commit_stmt, + TransactionState::Idle, + ) + .await?; + } + } + + if let Some(abort_stmt) = dist_abort { + let distributed_abort_res = distributed_abort(clnt, all_conns, &abort_stmt).await?; + if is_distributed_commit_failed { + // Nothing to do, as the error response has already been sent. + } else if let Some(err) = distributed_abort_res { + error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; + } else { + custom_protocol_response_ok_with_state( + &mut clnt.write, + &abort_stmt, + TransactionState::Idle, + ) + .await?; + } + } + + let is_all_servers_in_non_copy_mode = + all_conns.iter().all(|(_, conn)| !conn.0.in_copy_mode()); + + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. + if (clnt.is_transaction_mode() || clnt.is_transparent_mode()) + && is_all_servers_in_non_copy_mode + { + clnt.stats.idle(); + } + + if is_chained { + let last_conn = all_conns + .get_mut(clnt.last_server_key.as_ref().unwrap()) + .unwrap(); + let last_server = &mut *last_conn.0; + + // TODO(MD): chained transaction should be implemented. + // Here, we need to start a local transaction on the last server. However, here is + // too late to start a transaction, as we are far from the transaction loop. We need to + // rearrange the code (or add more complicated control flow) to make it possible. + warn!( + "Chained transaction is not implemented yet. \ + The last server {} will NOT be in transaction.", + last_server.address() + ); + } + }) +} + +pub fn reset_client_xact(clnt: &mut Client) { + // Reset transaction state for safety reasons.
+ clnt.xact_info = Default::default(); +} + +async fn distributed_commit( + clnt: &mut Client, + all_conns: &mut HashMap< + (usize, Option), + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, +) -> Result, Error> +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + debug!("Committing distributed transaction."); + if is_any_server_in_failed_xact(all_conns) { + #[cfg(debug_assertions)] + all_conns.iter().for_each(|(server_key, conn)| { + let server = &*conn.0; + if in_failed_transaction(server) { + debug!( + "Server {} (with server_key: {:?}) is in failed transaction. Skipping commit.", + server.address(), + server_key, + ); + } + }); + + let err = ErrorInfo::new_brief( + "Error".to_string(), + "25P02".to_string(), + "Cannot commit a transaction that is in failed state.".to_string(), + ); + + return Ok(Some(ErrorResponse::from(err))); + } + let res = distributed_prepare(clnt, all_conns).await?; + if res.is_right() { + return Ok(res.right()); + } + let res = res.left().unwrap(); + + let commit_prepared_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + local_server_commit_prepared(server, res.get_max_prepare_timestamp()) + })) + .await; + + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + set_post_query_state(clnt, server); + }); + + for commit_prepared_res in commit_prepared_results { + if let Some(err) = commit_prepared_res? { + // For now, we just return the first error we encounter. + return Ok(Some(err)); + } + } + + Ok(None) +} + +/// After each interaction with the server, we need to set the transaction state based on the +/// server's state. 
+fn set_post_query_state(clnt: &mut Client, server: &mut Server) { + clnt.xact_info + .set_state(server.transaction_metadata().state()); +} + +async fn distributed_abort( + clnt: &mut Client, + all_conns: &mut HashMap< + (usize, Option), + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + abort_stmt: &String, +) -> Result, Error> +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + debug!("Aborting distributed transaction"); + let abort_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + server.query(abort_stmt) + })) + .await; + + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + set_post_query_state(clnt, server); + server + .stats() + .transaction(&clnt.server_parameters.get_application_name()); + }); + + for abort_res in abort_results { + if let Some(err) = abort_res? { + // For now, we just return the first error we encounter. + return Ok(Some(err)); + } + } + Ok(None) +} + +async fn distributed_prepare( + clnt: &mut Client, + all_conns: &mut HashMap< + (usize, Option), + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, +) -> Result, Error> +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + // Apply 'PREPARE TRANSACTION' on all involved servers. + let prepare_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + set_post_query_state(clnt, server); + local_server_prepare_transaction(server) + })) + .await; + + // Update the client state based on the server state. + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + set_post_query_state(clnt, server); + }); + + // If there was any error, we need to abort the transaction. + let mut res = DistributedPrepareResult::default(); + for prepare_res in prepare_results { + if let Some(err) = prepare_res? 
{ + // For now, we just return the first error we encounter. + return Ok(Either::Right(err)); + } + } + + // Otherwise, accumulate the results of 'PREPARE TRANSACTION'. + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + res.accumulate(&server); + }); + Ok(Either::Left(res)) +} + +/// This function is called when the client sends a query to the server without requiring an answer. +async fn query_server( + clnt: &mut Client, + server: &mut Server, + stmt: &str, +) -> Result, Error> +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let qres = server.query(stmt).await?; + set_post_query_state(clnt, server); + Ok(qres) +} + +/// Returns true if the statement is a commit or abort statement. Also, it sets the commit or abort +/// statement on the client. +pub fn set_commit_or_abort_statement(clnt: &mut Client, ast: &Vec) -> bool { + if is_commit_statement(&ast) { + clnt.xact_info.set_commit_statement(Some(ast[0].clone())); + true + } else if is_abort_statement(&ast) { + clnt.xact_info.set_abort_statement(Some(ast[0].clone())); + true + } else { + false + } +} + +/// Returns true if the statement is a commit statement. +fn is_commit_statement(ast: &Vec) -> bool { + for statement in ast { + match *statement { + Statement::Commit { .. } => { + assert_eq!(ast.len(), 1); + return true; + } + _ => (), + } + } + false +} + +/// Returns true if the statement is an abort statement. +fn is_abort_statement(ast: &Vec) -> bool { + for statement in ast { + match *statement { + Statement::Rollback { .. } => { + assert_eq!(ast.len(), 1); + return true; + } + _ => (), + } + } + false +} + +/// Returns true if the commit or abort statement should be chained. 
+fn should_be_chained(dist_commit: Option<&Statement>, dist_abort: Option<&Statement>) -> bool { + dist_commit + .map(|stmt| match stmt { + Statement::Commit { chain } => *chain, + _ => false, + }) + .unwrap_or(false) + || dist_abort + .map(|stmt| match stmt { + Statement::Rollback { chain } => *chain, + _ => false, + }) + .unwrap_or(false) +} + +/// Send an error response to the client. +pub async fn error_response_stmt( + stream: &mut S, + err: &ErrorResponse, + t_state: TransactionState, +) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let mut err_bytes = BytesMut::new(); + err.encode(&mut err_bytes)?; + write_all_half(stream, &err_bytes).await?; + + ready_for_query_with_state(stream, t_state).await +} diff --git a/src/config.rs b/src/config.rs index 0404abc9..0bf9ab0b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -481,6 +481,9 @@ impl Default for General { /// - session: server is attached to the client. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)] pub enum PoolMode { + #[serde(alias = "transparent", alias = "Transparent")] + Transparent, + #[serde(alias = "transaction", alias = "Transaction")] Transaction, @@ -491,6 +494,7 @@ pub enum PoolMode { impl ToString for PoolMode { fn to_string(&self) -> String { match *self { + PoolMode::Transparent => "transparent".to_string(), PoolMode::Transaction => "transaction".to_string(), PoolMode::Session => "session".to_string(), } @@ -867,15 +871,26 @@ pub struct Plugins { pub prewarmer: Option, } +pub trait Plugin { + fn is_enabled(&self) -> bool; +} + impl std::fmt::Display for Plugins { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn is_enabled(arg: Option<&T>) -> bool { + if arg.is_some() { + arg.unwrap().is_enabled() + } else { + false + } + } write!( f, "interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}", - self.intercept.is_some(), - self.table_access.is_some(), - self.query_logger.is_some(), - 
self.prewarmer.is_some(), + is_enabled(self.intercept.as_ref()), + is_enabled(self.table_access.as_ref()), + is_enabled(self.query_logger.as_ref()), + is_enabled(self.prewarmer.as_ref()), ) } } @@ -886,23 +901,47 @@ pub struct Intercept { pub queries: BTreeMap, } +impl Plugin for Intercept { + fn is_enabled(&self) -> bool { + self.enabled + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct TableAccess { pub enabled: bool, pub tables: Vec, } +impl Plugin for TableAccess { + fn is_enabled(&self) -> bool { + self.enabled + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct QueryLogger { pub enabled: bool, } +impl Plugin for QueryLogger { + fn is_enabled(&self) -> bool { + self.enabled + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct Prewarmer { pub enabled: bool, pub queries: Vec, } +impl Plugin for Prewarmer { + fn is_enabled(&self) -> bool { + self.enabled + } +} + impl Intercept { pub fn substitute(&mut self, db: &str, user: &str) { for (_, query) in self.queries.iter_mut() { diff --git a/src/lib.rs b/src/lib.rs index 6a8a1e36..2d64b216 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod admin; pub mod auth_passthrough; pub mod client; +pub mod client_xact; pub mod cmd_args; pub mod config; pub mod constants; @@ -12,9 +13,11 @@ pub mod mirrors; pub mod plugins; pub mod pool; pub mod prometheus; +pub mod query_messages; pub mod query_router; pub mod scram; pub mod server; +pub mod server_xact; pub mod sharding; pub mod stats; pub mod tls; diff --git a/src/messages.rs b/src/messages.rs index 07fe9317..b1b5cb87 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -12,6 +12,7 @@ use crate::config::get_config; use crate::errors::Error; use crate::constants::MESSAGE_TERMINATOR; +use crate::server_xact::TransactionState; use std::collections::HashMap; use std::ffi::CString; use std::fmt::{Display, Formatter}; @@ -115,6 
+116,17 @@ pub fn simple_query(query: &str) -> BytesMut { /// Tell the client we're ready for another query. pub async fn ready_for_query(stream: &mut S) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + ready_for_query_with_state(stream, TransactionState::Idle).await +} + +/// Tell the client we're ready for another query or not, based on the given transaction state. +pub async fn ready_for_query_with_state( + stream: &mut S, + t_state: TransactionState, +) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { @@ -124,7 +136,11 @@ where bytes.put_u8(b'Z'); bytes.put_i32(5); - bytes.put_u8(b'I'); // Idle + match t_state { + TransactionState::Idle => bytes.put_u8(b'I'), + TransactionState::InTransaction => bytes.put_u8(b'T'), + TransactionState::InFailedTransaction => bytes.put_u8(b'E'), + } write_all(stream, bytes).await } @@ -311,6 +327,24 @@ pub async fn custom_protocol_response_ok(stream: &mut S, message: &str) -> Re where S: tokio::io::AsyncWrite + std::marker::Unpin, { + custom_protocol_response_ok_with_state(stream, message, TransactionState::Idle).await +} + +/// Implements a response to our custom `SET SHARDING KEY` +/// and `SET SERVER ROLE` commands. +/// This tells the client we're ready for the next query or not, based on the state. +pub async fn custom_protocol_response_ok_with_state( + stream: &mut S, + message: &str, + t_state: TransactionState, +) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + debug!( + "Sending custom protocol response: {} at {:?} state.", + message, t_state + ); let mut res = BytesMut::with_capacity(25); let set_complete = BytesMut::from(&format!("{}\0", message)[..]); @@ -322,18 +356,34 @@ where res.put_slice(&set_complete[..]); write_all_half(stream, &res).await?; - ready_for_query(stream).await + ready_for_query_with_state(stream, t_state).await } /// Send a custom error message to the client. 
/// Tell the client we are ready for the next query and no rollback is necessary. /// Docs on error codes: . pub async fn error_response(stream: &mut S, message: &str) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + error_response_with_state(stream, message, TransactionState::Idle).await +} + +/// Send a custom error message to the client. +/// Tell the client we are ready for the next query. Given the current transaction state, no +/// rollback is necessary if it's in the "idle" or "transaction" state (i.e., not already in the +/// rollback state). +/// Docs on error codes: . +pub async fn error_response_with_state( + stream: &mut S, + message: &str, + t_state: TransactionState, +) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { error_response_terminal(stream, message).await?; - ready_for_query(stream).await + ready_for_query_with_state(stream, t_state).await } /// Send a custom error message to the client. diff --git a/src/pool.rs b/src/pool.rs index 18123407..1e6501f5 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -92,7 +92,7 @@ impl From<&Address> for PoolIdentifier { /// Pool settings. #[derive(Clone, Debug)] pub struct PoolSettings { - /// Transaction or Session. + /// Transparent, Transaction or Session. pub pool_mode: PoolMode, /// Random or LeastOutstandingConnections. diff --git a/src/query_messages.rs b/src/query_messages.rs new file mode 100644 index 00000000..3e0d9bae --- /dev/null +++ b/src/query_messages.rs @@ -0,0 +1,456 @@ +/// Helper functions to send one-off protocol messages +/// and handle TcpStream (TCP socket). 
+use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::errors::Error; + +pub type Oid = u32; + +#[derive(PartialEq, Eq, Debug, Default)] +pub struct FieldDescription { + // the field name + name: String, + // the object ID of table, default to 0 if not a table + table_id: i32, + // the attribute number of the column, default to 0 if not a column from table + column_id: i16, + // the object ID of the data type + type_id: Oid, + // the size of data type, negative values denote variable-width types + type_size: i16, + // the type modifier + type_modifier: i32, + // the format code being used for the field, will be 0 or 1 for now + format_code: i16, +} + +pub type PgWireResult = Result; + +/// Get null-terminated string, returns None when empty cstring read. +/// +/// Note that this implementation will also advance cursor by 1 after reading +/// empty cstring. This behaviour matches how the postgres wire protocol handles +/// key-value pairs, which are terminated by a single `\0` +pub(crate) fn get_cstring(buf: &mut BytesMut) -> Option { + let mut i = 0; + + // with bound check to prevent invalid format + while i < buf.remaining() && buf[i] != b'\0' { + i += 1; + } + + // i+1: include the '\0' + // move cursor to the end of cstring + let string_buf = buf.split_to(i + 1); + + if i == 0 { + None + } else { + Some(String::from_utf8_lossy(&string_buf[..i]).into_owned()) + } +} + +/// Put null-terminated string +/// +/// You can put empty string by giving `""` as input.
+pub(crate) fn put_cstring(buf: &mut BytesMut, input: &str) { + buf.put_slice(input.as_bytes()); + buf.put_u8(b'\0'); +} + +/// Try to read message length from buf, without actually moving the cursor +pub(crate) fn get_length(buf: &BytesMut, offset: usize) -> Option { + if buf.remaining() >= 4 + offset { + Some((&buf[offset..4 + offset]).get_i32() as usize) + } else { + None + } +} + +/// Check if message_length matches and move the cursor to the right position then +/// call the `decode_fn` for the body +pub(crate) fn decode_packet( + buf: &mut BytesMut, + offset: usize, + decode_fn: F, +) -> PgWireResult> +where + F: Fn(&mut BytesMut, usize) -> PgWireResult, +{ + if let Some(msg_len) = get_length(buf, offset) { + if buf.remaining() >= msg_len + offset { + buf.advance(offset + 4); + return decode_fn(buf, msg_len).map(|r| Some(r)); + } + } + + Ok(None) +} + +/// Define how a message is encoded and decoded. +pub trait Message: Sized { + /// Return the type code of the message. In order to maintain backward + /// compatibility, `Startup` has no message type. + #[inline] + fn message_type() -> Option { + None + } + + /// Return the length of the message, including the length integer itself. + fn message_length(&self) -> usize; + + /// Encode body part of the message. + fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()>; + + /// Decode body part of the message. + fn decode_body(buf: &mut BytesMut, full_len: usize) -> PgWireResult; + + /// Default implementation for encoding message. + /// + /// Message type and length are encoded in this implementation and it calls + /// `encode_body` for remaining parts. + fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> { + if let Some(mt) = Self::message_type() { + buf.put_u8(mt); + } + + buf.put_i32(self.message_length() as i32); + self.encode_body(buf) + } + + /// Default implementation for decoding message.
+ /// + /// Message type and length are decoded in this implementation and it calls + /// `decode_body` for remaining parts. Return `None` if the packet is not + /// complete for parsing. + fn decode(buf: &mut BytesMut) -> PgWireResult> { + let offset = Self::message_type().is_some().into(); + + decode_packet(buf, offset, |buf, full_len| { + Self::decode_body(buf, full_len) + }) + } +} + +pub const MESSAGE_TYPE_BYTE_ROW_DESCRITION: u8 = b'T'; + +#[derive(PartialEq, Eq, Debug, Default)] +pub struct RowDescription { + fields: Vec, +} + +impl RowDescription { + pub fn fields(&self) -> &[FieldDescription] { + &self.fields + } +} + +impl Message for RowDescription { + fn message_type() -> Option { + Some(MESSAGE_TYPE_BYTE_ROW_DESCRITION) + } + + fn message_length(&self) -> usize { + 4 + 2 + + self + .fields + .iter() + .map(|f| f.name.as_bytes().len() + 1 + 4 + 2 + 4 + 2 + 4 + 2) + .sum::() + } + + fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { + buf.put_i16(self.fields.len() as i16); + + for field in &self.fields { + put_cstring(buf, &field.name); + buf.put_i32(field.table_id); + buf.put_i16(field.column_id); + buf.put_u32(field.type_id); + buf.put_i16(field.type_size); + buf.put_i32(field.type_modifier); + buf.put_i16(field.format_code); + } + + Ok(()) + } + + fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { + let fields_len = buf.get_i16(); + let mut fields = Vec::with_capacity(fields_len as usize); + + for _ in 0..fields_len { + let field = FieldDescription { + name: get_cstring(buf).unwrap_or_else(|| "".to_owned()), + table_id: buf.get_i32(), + column_id: buf.get_i16(), + type_id: buf.get_u32(), + type_size: buf.get_i16(), + type_modifier: buf.get_i32(), + format_code: buf.get_i16(), + }; + + fields.push(field); + } + + Ok(RowDescription { fields }) + } +} + +/// Data structure for postgresql wire protocol `DataRow` message. 
+/// +/// Data can be represented as text or binary format as specified by format +/// codes from previous `RowDescription` message. +#[derive(PartialEq, Eq, Debug, Default, Clone)] +pub struct DataRow { + fields: Vec>, +} + +impl DataRow { + pub fn fields(&self) -> &[Option] { + &self.fields + } +} + +pub const MESSAGE_TYPE_BYTE_DATA_ROW: u8 = b'D'; + +impl Message for DataRow { + #[inline] + fn message_type() -> Option { + Some(MESSAGE_TYPE_BYTE_DATA_ROW) + } + + fn message_length(&self) -> usize { + 4 + 2 + + self + .fields + .iter() + .map(|b| b.as_ref().map(|b| b.len() + 4).unwrap_or(4)) + .sum::() + } + + fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { + buf.put_i16(self.fields.len() as i16); + for field in &self.fields { + if let Some(bytes) = field { + buf.put_i32(bytes.len() as i32); + buf.put_slice(bytes.as_ref()); + } else { + buf.put_i32(-1); + } + } + + Ok(()) + } + + fn decode_body(buf: &mut BytesMut, _msg_len: usize) -> PgWireResult { + let field_count = buf.get_i16() as usize; + + let mut fields = Vec::with_capacity(field_count); + for _ in 0..field_count { + let field_len = buf.get_i32(); + if field_len >= 0 { + fields.push(Some(buf.split_to(field_len as usize).freeze())); + } else { + fields.push(None); + } + } + + Ok(DataRow { fields }) + } +} + +#[derive(PartialEq, Eq, Debug, Default)] +pub struct QueryResponse { + row_desc: RowDescription, + data_rows: Vec, +} + +impl QueryResponse { + pub fn new(row_desc: RowDescription, data_rows: Vec) -> Self { + Self { + row_desc, + data_rows, + } + } + + pub fn row_desc(&self) -> &RowDescription { + &self.row_desc + } + + pub fn data_rows(&self) -> &[DataRow] { + &self.data_rows + } +} + +// Postgres error and notice message fields +// This part of protocol is defined in +// https://www.postgresql.org/docs/8.2/protocol-error-fields.html +#[derive(Debug, Default)] +pub struct ErrorInfo { + // severity can be one of `ERROR`, `FATAL`, or `PANIC` (in an error + // message), or `WARNING`, 
`NOTICE`, `DEBUG`, `INFO`, or `LOG` (in a notice + // message), or a localized translation of one of these. + severity: String, + // error code defined in + // https://www.postgresql.org/docs/current/errcodes-appendix.html + code: String, + // readable message + message: String, + // optional secondary message + detail: Option, + // optional suggestion for fixing the issue + hint: Option, + // Position: the field value is a decimal ASCII integer, indicating an error + // cursor position as an index into the original query string. + position: Option, + // Internal position: this is defined the same as the P field, but it is + // used when the cursor position refers to an internally generated command + // rather than the one submitted by the client + internal_position: Option, + // Internal query: the text of a failed internally-generated command. + internal_query: Option, + // Where: an indication of the context in which the error occurred. + where_context: Option, + // File: the file name of the source-code location where the error was + // reported. + file_name: Option, + // Line: the line number of the source-code location where the error was + // reported. + line: Option, + // Routine: the name of the source-code routine reporting the error. 
+ routine: Option, +} + +impl ErrorInfo { + pub fn new( + severity: String, + code: String, + message: String, + detail: Option, + hint: Option, + position: Option, + internal_position: Option, + internal_query: Option, + where_context: Option, + file_name: Option, + line: Option, + routine: Option, + ) -> Self { + Self { + severity, + code, + message, + detail, + hint, + position, + internal_position, + internal_query, + where_context, + file_name, + line, + routine, + } + } + pub fn new_brief(severity: String, code: String, message: String) -> Self { + Self::new( + severity, code, message, None, None, None, None, None, None, None, None, None, + ) + } +} + +impl ErrorInfo { + fn into_fields(self) -> Vec<(u8, String)> { + let mut fields = Vec::with_capacity(11); + + fields.push((b'S', self.severity)); + fields.push((b'C', self.code)); + fields.push((b'M', self.message)); + if let Some(value) = self.detail { + fields.push((b'D', value)); + } + if let Some(value) = self.hint { + fields.push((b'H', value)); + } + if let Some(value) = self.position { + fields.push((b'P', value)); + } + if let Some(value) = self.internal_position { + fields.push((b'p', value)); + } + if let Some(value) = self.internal_query { + fields.push((b'q', value)); + } + if let Some(value) = self.where_context { + fields.push((b'W', value)); + } + if let Some(value) = self.file_name { + fields.push((b'F', value)); + } + if let Some(value) = self.line { + fields.push((b'L', value.to_string())); + } + if let Some(value) = self.routine { + fields.push((b'R', value)); + } + + fields + } +} + +impl From for ErrorResponse { + fn from(ei: ErrorInfo) -> ErrorResponse { + ErrorResponse { + fields: ei.into_fields(), + } + } +} + +/// postgres error response, sent from backend to frontend +#[derive(PartialEq, Eq, Debug, Default)] +pub struct ErrorResponse { + fields: Vec<(u8, String)>, +} + +pub const MESSAGE_TYPE_BYTE_ERROR_RESPONSE: u8 = b'E'; + +impl Message for ErrorResponse { + #[inline] + fn 
message_type() -> Option { + Some(MESSAGE_TYPE_BYTE_ERROR_RESPONSE) + } + + fn message_length(&self) -> usize { + 4 + self + .fields + .iter() + .map(|f| 1 + f.1.as_bytes().len() + 1) + .sum::() + + 1 + } + + fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { + for (code, value) in &self.fields { + buf.put_u8(*code); + put_cstring(buf, value); + } + + buf.put_u8(b'\0'); + + Ok(()) + } + + fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { + let mut fields = Vec::new(); + loop { + let code = buf.get_u8(); + + if code == b'\0' { + return Ok(ErrorResponse { fields }); + } else { + let value = get_cstring(buf).unwrap_or_else(|| "".to_owned()); + fields.push((code, value)); + } + } + } +} diff --git a/src/server.rs b/src/server.rs index 70c8270d..9445a3bd 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,7 +10,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use std::mem; use std::net::IpAddr; use std::sync::Arc; -use std::time::SystemTime; +use std::time::{Instant, SystemTime}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::net::TcpStream; use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; @@ -24,7 +24,9 @@ use crate::messages::BytesMutReader; use crate::messages::*; use crate::mirrors::MirroringManager; use crate::pool::ClientServerMap; +use crate::query_messages::{ErrorResponse, Message}; use crate::scram::ScramSha256; +use crate::server_xact::*; use crate::stats::ServerStats; use std::io::Write; @@ -106,7 +108,7 @@ impl StreamInner { } #[derive(Copy, Clone)] -struct CleanupState { +pub(crate) struct CleanupState { /// If server connection requires RESET ALL before checkin because of set statement needs_cleanup_set: bool, @@ -131,7 +133,7 @@ impl CleanupState { self.needs_cleanup_prepare = true; } - fn reset(&mut self) { + pub(crate) fn reset(&mut self) { self.needs_cleanup_set = false; self.needs_cleanup_prepare = false; } @@ -154,12 +156,15 @@ static TRACKED_PARAMETERS: Lazy> = Lazy::new(|| 
{ set.insert("TimeZone".to_string()); set.insert("standard_conforming_strings".to_string()); set.insert("application_name".to_string()); + for param in TRANSACTION_PARAMETERS.iter() { + set.insert(param.clone()); + } set }); #[derive(Debug, Clone)] pub struct ServerParameters { - parameters: HashMap, + pub(crate) parameters: HashMap, } impl Default for ServerParameters { @@ -183,6 +188,7 @@ impl ServerParameters { false, ); server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false); + TransactionParameters::set_default_server_parameters(&mut server_parameters); server_parameters } @@ -269,7 +275,7 @@ impl From<&ServerParameters> for BytesMut { pub struct Server { /// Server host, e.g. localhost, /// port, e.g. 5432, and role, e.g. primary or replica. - address: Address, + pub(crate) address: Address, /// Server TCP connection. stream: BufStream, @@ -278,14 +284,14 @@ pub struct Server { buffer: BytesMut, /// Server information the server sent us over on startup. - server_parameters: ServerParameters, + pub(crate) server_parameters: ServerParameters, /// Backend id and secret key used for query cancellation. process_id: i32, secret_key: i32, /// Is the server inside a transaction or idle. - in_transaction: bool, + pub(crate) transaction_metadata: TransactionMetaData, /// Is there more data for the client to read. data_available: bool, @@ -297,7 +303,7 @@ pub struct Server { bad: bool, /// If server connection requires reset statements before checkin - cleanup_state: CleanupState, + pub(crate) cleanup_state: CleanupState, /// Mapping of clients and servers used for query cancellation. 
client_server_map: ClientServerMap, @@ -505,7 +511,7 @@ impl Server { } }; - trace!("Message: {}", code); + trace!("server Message: {}", code); match code { // Authentication @@ -802,14 +808,14 @@ impl Server { } }; - let server = Server { + let mut server = Server { address: address.clone(), stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), server_parameters, process_id, secret_key, - in_transaction: false, + transaction_metadata: Default::default(), in_copy_mode: false, data_available: false, bad: false, @@ -833,6 +839,9 @@ impl Server { prepared_statements: BTreeSet::new(), }; + // We want to make sure that all servers are operating on the same isolation level. + sync_given_parameter_keys(&mut server, &TRANSACTION_PARAMETERS).await?; + return Ok(server); } @@ -925,7 +934,7 @@ impl Server { let code = message.get_u8() as char; let _len = message.get_i32(); - trace!("Message: {}", code); + trace!("recv Message: {}", code); match code { // ReadyForQuery @@ -935,17 +944,17 @@ impl Server { match transaction_state { // In transaction. 'T' => { - self.in_transaction = true; + self.transaction_metadata.state = TransactionState::InTransaction; } // Idle, transaction over. 'I' => { - self.in_transaction = false; + self.transaction_metadata.state = TransactionState::Idle; } // Some error occurred, the transaction was rolled back. 'E' => { - self.in_transaction = true; + self.transaction_metadata.state = TransactionState::InFailedTransaction; } // Something totally unexpected, this is not a Postgres server we know. 
@@ -988,7 +997,7 @@ impl Server { // No great way to differentiate between set and set local // As a result, we will miss cases when set statements are used in transactions // This will reduce amount of reset statements sent - if !self.in_transaction { + if !self.in_transaction() { debug!("Server connection marked for clean up"); self.cleanup_state.needs_cleanup_set = true; } @@ -1195,8 +1204,8 @@ impl Server { /// If the server is still inside a transaction. /// If the client disconnects while the server is in a transaction, we will clean it up. pub fn in_transaction(&self) -> bool { - debug!("Server in transaction: {}", self.in_transaction); - self.in_transaction + debug!("Server in transaction: {:?}", self.transaction_metadata); + !self.transaction_metadata.is_idle() } pub fn in_copy_mode(&self) -> bool { @@ -1239,21 +1248,7 @@ impl Server { pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> { let parameter_diff = self.server_parameters.compare_params(parameters); - if parameter_diff.is_empty() { - return Ok(()); - } - - let mut query = String::from(""); - - for (key, value) in parameter_diff { - query.push_str(&format!("SET {} TO '{}';", key, value)); - } - - let res = self.query(&query).await; - - self.cleanup_state.reset(); - - res + sync_given_parameter_key_values(self, ¶meter_diff).await } /// Indicate that this server connection cannot be re-used and must be discarded. @@ -1279,22 +1274,31 @@ impl Server { /// Execute an arbitrary query against the server. /// It will use the simple query protocol. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. 
- pub async fn query(&mut self, query: &str) -> Result<(), Error> { - debug!("Running `{}` on server {:?}", query, self.address); + pub async fn query(&mut self, query: &str) -> Result, Error> { + debug!("Running `{}` on server {}", query, self.address); let query = simple_query(query); self.send(&query).await?; + let query_start = Instant::now(); + loop { - let _ = self.recv(None).await?; + let mut response = self.recv(None).await?; + + if response[0] == b'E' { + let err = ErrorResponse::decode(&mut response)?.unwrap(); + return Ok(Some(err)); + } if !self.data_available { break; } } - Ok(()) + query_time_stats(self, query_start); + + Ok(None) } /// Perform any necessary cleanup before putting the server @@ -1334,6 +1338,8 @@ impl Server { warn!(target: "pgcat::server::cleanup", "Server returned while still in copy-mode"); } + self.transaction_metadata = Default::default(); + Ok(()) } @@ -1399,6 +1405,14 @@ impl Server { Ok(parse_query_message(&mut message).await?) } + + pub fn transaction_metadata(&self) -> &TransactionMetaData { + &self.transaction_metadata + } + + pub fn transaction_metadata_mut(&mut self) -> &mut TransactionMetaData { + &mut self.transaction_metadata + } } async fn parse_query_message(message: &mut BytesMut) -> Result, Error> { diff --git a/src/server_xact.rs b/src/server_xact.rs new file mode 100644 index 00000000..cb99643a --- /dev/null +++ b/src/server_xact.rs @@ -0,0 +1,544 @@ +/// Implementation of the PostgreSQL server (database) protocol. +/// Here we are pretending to the a Postgres client. 
+use bytes::Buf; +use chrono::NaiveDateTime; +use itertools::Either; +use log::{debug, warn}; +use once_cell::sync::Lazy; +use sqlparser::ast::{Statement, TransactionIsolationLevel}; +use std::collections::HashMap; +use std::time::Instant; + +use crate::errors::Error; +use crate::messages::*; +use crate::query_messages::{DataRow, ErrorResponse, Message, QueryResponse, RowDescription}; +use crate::server::{Server, ServerParameters}; + +/// The default transaction parameters that might be configured on the server. +pub static TRANSACTION_PARAMETERS: Lazy> = Lazy::new(|| { + let mut list = Vec::new(); + list.push("default_transaction_isolation".to_string()); + list.push("default_transaction_read_only".to_string()); + list.push("default_transaction_deferrable".to_string()); + list +}); + +/// The default transaction parameters that are either configured on the server or set by the +/// BEGIN statement. +#[derive(Debug, Clone)] +pub struct TransactionParameters { + isolation_level: TransactionIsolationLevel, + read_only: bool, + deferrable: bool, +} + +impl TransactionParameters { + pub fn new( + isolation_level: TransactionIsolationLevel, + read_only: bool, + deferrable: bool, + ) -> Self { + Self { + isolation_level, + read_only, + deferrable, + } + } + + pub fn get_isolation_level(&self) -> TransactionIsolationLevel { + self.isolation_level + } + + pub fn is_read_only(&self) -> bool { + self.read_only + } + + pub fn is_deferrable(&self) -> bool { + self.deferrable + } + + pub fn set_isolation_level(&mut self, isolation_level: TransactionIsolationLevel) { + self.isolation_level = isolation_level; + } + + pub fn set_read_only(&mut self, read_only: bool) { + self.read_only = read_only; + } + + pub fn set_deferrable(&mut self, deferrable: bool) { + self.deferrable = deferrable; + } + + pub fn is_serializable(&self) -> bool { + matches!( + self.get_isolation_level(), + TransactionIsolationLevel::Serializable + ) + } + + pub fn is_repeatable_read(&self) -> bool { + 
matches!( + self.get_isolation_level(), + TransactionIsolationLevel::RepeatableRead + ) + } + + pub fn is_repeatable_read_or_higher(&self) -> bool { + self.is_serializable() || self.is_repeatable_read() + } + + /// Sets the default transaction parameters on the given ServerParameters instance. + pub fn set_default_server_parameters(sparams: &mut ServerParameters) { + // TODO(MD): make these configurable + sparams.set_param( + "default_transaction_isolation".to_string(), + "read committed".to_string(), + false, + ); + sparams.set_param( + "default_transaction_read_only".to_string(), + "off".to_string(), + false, + ); + sparams.set_param( + "default_transaction_deferrable".to_string(), + "off".to_string(), + false, + ); + } +} + +impl Default for TransactionParameters { + fn default() -> Self { + Self::new(TransactionIsolationLevel::ReadCommitted, false, false) + } +} + +fn get_default_transaction_isolation(sparams: &ServerParameters) -> TransactionIsolationLevel { + // Can unwrap because we set it in the constructor + if let Some(isolation_level) = sparams.parameters.get("default_transaction_isolation") { + return match isolation_level.to_lowercase().as_str() { + "read committed" => TransactionIsolationLevel::ReadCommitted, + "repeatable read" => TransactionIsolationLevel::RepeatableRead, + "serializable" => TransactionIsolationLevel::Serializable, + "read uncommitted" => TransactionIsolationLevel::ReadUncommitted, + _ => TransactionIsolationLevel::ReadCommitted, + }; + } + TransactionIsolationLevel::ReadCommitted +} + +fn get_default_transaction_read_only(sparams: &ServerParameters) -> bool { + if let Some(is_readonly) = sparams.parameters.get("default_transaction_read_only") { + return !is_readonly.to_lowercase().eq("off"); + } + false +} + +fn get_default_transaction_deferrable(sparams: &ServerParameters) -> bool { + if let Some(deferrable) = sparams.parameters.get("default_transaction_deferrable") { + return !deferrable.to_lowercase().eq("off"); + } + false +} + 
+fn get_default_transaction_parameters(sparams: &ServerParameters) -> TransactionParameters { + TransactionParameters::new( + get_default_transaction_isolation(sparams), + get_default_transaction_read_only(sparams), + get_default_transaction_deferrable(sparams), + ) +} + +pub fn server_default_transaction_parameters(server: &Server) -> TransactionParameters { + get_default_transaction_parameters(&server.server_parameters) +} + +/// Sends some queries to the server to sync the given pramaters specified by 'keys'. +pub async fn sync_given_parameter_keys(server: &mut Server, keys: &[String]) -> Result<(), Error> { + let mut key_values = HashMap::new(); + for key in keys { + if let Some(value) = server.server_parameters.parameters.get(key) { + key_values.insert(key.clone(), value.clone()); + } + } + sync_given_parameter_key_values(server, &key_values).await +} + +/// Sends some queries to the server to sync the given pramaters specified by 'key_values'. +pub async fn sync_given_parameter_key_values( + server: &mut Server, + key_values: &HashMap, +) -> Result<(), Error> { + let mut query = String::from(""); + + for (key, value) in key_values { + query.push_str(&format!("SET {} TO '{}';", key, value)); + } + + let res = server.query(&query).await; + + server.cleanup_state.reset(); + + match res { + Ok(None) => Ok(()), + Ok(Some(err_res)) => { + warn!( + "Error while syncing parameters (was dropped): {:?}", + err_res + ); + Ok(()) + } + Err(err) => Err(err), + } +} + +/// Returnes true if the given server is in a failed transaction state. +pub fn in_failed_transaction(server: &Server) -> bool { + server.transaction_metadata.is_in_failed_transaction() +} + +/// The various states that a server transaction can be in. +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum TransactionState { + /// Server is idle. + Idle, + /// Server is in a transaction. + InTransaction, + /// Server is in a failed transaction. 
+ InFailedTransaction, +} + +/// The metadata of a server transaction. +#[derive(Debug, Clone)] +pub struct TransactionMetaData { + pub(crate) state: TransactionState, + + xact_gid: Option, + snapshot: Option, + prepared_timestamp: Option, + + begin_statement: Option, + commit_statement: Option, + abort_statement: Option, + + pub params: TransactionParameters, +} + +impl TransactionMetaData { + pub fn set_state(&mut self, state: TransactionState) { + match self.state { + TransactionState::Idle => { + self.state = state; + } + TransactionState::InTransaction => match state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a transaction."); + } + _ => { + self.state = state; + } + }, + TransactionState::InFailedTransaction => match state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a failed transaction."); + } + TransactionState::InTransaction => { + warn!("Cannot go back to a transaction from a failed transaction.") + } + _ => { + self.state = state; + } + }, + } + } + + pub fn state(&self) -> TransactionState { + self.state + } + + pub fn is_idle(&self) -> bool { + self.state == TransactionState::Idle + } + + pub fn is_in_transaction(&self) -> bool { + self.state == TransactionState::InTransaction + } + + pub fn is_in_failed_transaction(&self) -> bool { + self.state == TransactionState::InFailedTransaction + } + + pub fn set_xact_gid(&mut self, xact_gid: Option) { + self.xact_gid = xact_gid; + } + + pub fn get_xact_gid(&self) -> Option { + self.xact_gid.clone() + } + + pub fn set_snapshot(&mut self, snapshot: Option) { + self.snapshot = snapshot; + } + + pub fn get_snapshot(&self) -> Option { + self.snapshot.clone() + } + + pub fn set_prepared_timestamp(&mut self, prepared_timestamp: Option) { + self.prepared_timestamp = prepared_timestamp; + } + + pub fn get_prepared_timestamp(&self) -> Option { + self.prepared_timestamp + } + + pub fn has_done_prepare_transaction(&self) -> bool { + self.prepared_timestamp.is_some() + } + 
+ pub fn set_begin_statement(&mut self, begin_statement: Option) { + self.begin_statement = begin_statement; + } + + pub fn get_begin_statement(&self) -> Option<&Statement> { + self.begin_statement.as_ref() + } + + pub fn set_commit_statement(&mut self, commit_statement: Option) { + self.commit_statement = commit_statement; + } + + pub fn get_commit_statement(&self) -> Option<&Statement> { + self.commit_statement.as_ref() + } + + pub fn set_abort_statement(&mut self, abort_statement: Option) { + self.abort_statement = abort_statement; + } + + pub fn get_abort_statement(&self) -> Option<&Statement> { + self.abort_statement.as_ref() + } +} + +impl Default for TransactionMetaData { + fn default() -> Self { + Self { + state: TransactionState::Idle, + xact_gid: None, + snapshot: None, + prepared_timestamp: None, + begin_statement: None, + commit_statement: None, + abort_statement: None, + params: TransactionParameters::default(), + } + } +} + +/// Represents a read-write conflict. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct RWConflict { + pub source_gid: String, + pub gid_in: String, + pub gid_out: String, +} + +impl RWConflict { + pub fn new(source_gid: String, gid_in: String, gid_out: String) -> Self { + Self { + source_gid, + gid_in, + gid_out, + } + } +} + +/// Given the query start time, this function registers the query execution time to the server +/// stats. +pub fn query_time_stats(server: &mut Server, query_start: Instant) { + server.stats().query( + Instant::now().duration_since(query_start).as_millis() as u64, + &server.server_parameters.get_application_name(), + ); +} + +/// Execute an arbitrary query against the server. +/// It will use the simple query protocol. +/// Result will be returned, so this is useful for things like `SELECT`. 
+async fn query_with_response( + server: &mut Server, + query: &str, +) -> Result, Error> { + debug!( + "Running `{}` on server {} and capture response.", + query, server.address + ); + + let query = simple_query(query); + + server.send(&query).await?; + + let query_start = Instant::now(); + // Read all data the server has to offer, which can be multiple messages + // buffered in 8196 bytes chunks. + let mut response = server.recv(None).await?; + + let query_response = match response[0] { + b'T' => { + let row_desc = RowDescription::decode(&mut response)?.unwrap(); + + let mut data_rows = Vec::new(); + + loop { + if response.remaining() == 0 { + if server.is_data_available() { + response = server.recv(None).await?; + } else { + break; + } + } + if response[0] == b'C' { + break; + } + + let data_row = DataRow::decode(&mut response)?.unwrap(); + data_rows.push(data_row); + } + + QueryResponse::new(row_desc, data_rows) + } + + b'E' => { + let err = ErrorResponse::decode(&mut response)?.unwrap(); + return Ok(Either::Right(err)); + } + + _ => return Err(Error::ServerError), + }; + + query_time_stats(server, query_start); + + Ok(Either::Left(query_response)) +} + +/// Captures the snapshot from the server. 
+pub async fn acquire_xact_snapshot( + server: &mut Server, +) -> Result, Error> { + let qres = query_with_response(server, "select pg_export_snapshot()").await?; + + if qres.is_right() { + return Ok(Either::Right(qres.right().unwrap())); + } + + let qres = qres.left().unwrap(); + + let qres_rows: &[DataRow] = qres.data_rows(); + assert!(qres.row_desc().fields().len() == 1); + assert!(qres_rows.len() == 1); + if let Some(snapshot) = qres_rows[0].fields().get(0).unwrap() { + let snapshot = std::str::from_utf8(&snapshot).unwrap().to_string(); + + debug!("Got snapshot: {}", snapshot); + + server + .transaction_metadata + .set_snapshot(Some(snapshot.clone())); + + Ok(Either::Left(snapshot)) + } else { + Err(Error::BadQuery( + "Could not get snapshot from server".to_string(), + )) + } +} + +/// Sets the snapshot to the server (based on a previous snapshot acquired by the first server). +pub async fn assign_xact_snapshot( + server: &mut Server, + snapshot: &str, +) -> Result, Error> { + server + .query(&format!("set transaction snapshot '{snapshot}'")) + .await +} + +/// Sets the GID on the server. If we are in serializable mode, we need to register the GID to +/// the remote postgres instance, too. 
+pub async fn assign_xact_gid( + server: &mut Server, + gid: &str, +) -> Result, Error> { + server + .transaction_metadata + .set_xact_gid(Some(gid.to_string())); + Ok(None) +} + +pub async fn local_server_prepare_transaction( + server: &mut Server, +) -> Result, Error> { + debug!( + "Called local_server_prepare_transaction on {}", + server.address, + ); + + let xact_gid = server.transaction_metadata.get_xact_gid(); + if xact_gid.is_none() { + return Err(Error::BadQuery(format!( + "There is no GID assigned to the current transaction while it's requested to be \ + prepared to commit on the server ({}).", + server.address() + ))); + } + let xact_gid = xact_gid.unwrap(); + + if let Some(prep_time) = server.transaction_metadata.get_prepared_timestamp() { + return Err(Error::BadQuery(format!( + "The server ({}) was prepared in the past: {} (with gid: {})", + server.address(), + prep_time, + xact_gid, + ))); + } + + let qres = server + .query(&format!("PREPARE TRANSACTION '{}'", xact_gid)) + .await?; + if qres.is_some() { + return Ok(qres); + } + + Ok(None) +} + +pub async fn local_server_commit_prepared( + server: &mut Server, + commit_ts: NaiveDateTime, +) -> Result, Error> { + debug!( + "Called local_server_commit_prepared on {} with commit_ts: {:?}", + server.address, commit_ts + ); + + let xact_gid = server.transaction_metadata.get_xact_gid(); + if xact_gid.is_none() { + return Err(Error::BadQuery( + "The current connection is not attached to a \ + transaction while it's requested to be prepared to commit." + .to_string(), + )); + } + let xact_gid = xact_gid.unwrap(); + + let qres = server + .query(&format!("COMMIT PREPARED '{}'", xact_gid)) + .await?; + if qres.is_some() { + return Ok(qres); + } + + Ok(None) +} From 743ef39c3600ad08eb60d6dad46ccf9c2480a4b8 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 22 Sep 2023 02:39:24 -0700 Subject: [PATCH 02/15] Reverted `Docker` file changes. 
--- Dockerfile | 17 ++++------------- Dockerfile.dev | 19 +++++-------------- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/Dockerfile b/Dockerfile index 57c4ec11..ec29182a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,19 +1,10 @@ -FROM ubuntu:23.04 AS builder - -RUN apt-get update && \ - apt-get install -y build-essential curl - -# Get Rust -RUN curl https://sh.rustup.rs -sSf | bash -s -- -y - -ENV PATH="/root/.cargo/bin:${PATH}" - +FROM rust:1 AS builder COPY . /app WORKDIR /app -RUN /root/.cargo/bin/cargo build --release +RUN cargo build --release -FROM ubuntu:23.04 -COPY --from=builder /app/target/release /usr/bin/ +FROM debian:bullseye-slim +COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml WORKDIR /etc/pgcat ENV RUST_LOG=info diff --git a/Dockerfile.dev b/Dockerfile.dev index 1b3cd3c1..a4b8d0ed 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -1,16 +1,7 @@ -FROM ubuntu:23.04 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1 AS chef RUN apt-get update && \ - apt-get install -y build-essential curl - -# Get Rust -RUN curl https://sh.rustup.rs -sSf | bash -s -- -y - -ENV PATH="/root/.cargo/bin:${PATH}" - -# We only pay the installation cost once, -# it will be cached from the second build onwards -RUN cargo install cargo-chef + apt-get install -y build-essential WORKDIR /app @@ -24,10 +15,10 @@ COPY --from=planner /app/recipe.json recipe.json RUN cargo chef cook --release --recipe-path recipe.json # Build application COPY . . 
-RUN cargo build --release +RUN cargo build -FROM ubuntu:23.04 -COPY --from=builder /app/target/release /usr/bin/ +FROM debian:bookworm-slim +COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml WORKDIR /etc/pgcat ENV RUST_LOG=info From d3ffdeaf1a158c16c4a335a87cb3bdc0f41048ed Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 22 Sep 2023 02:52:57 -0700 Subject: [PATCH 03/15] Reverted the changes related to PR #536. --- src/config.rs | 43 ++++--------------------------------------- 1 file changed, 4 insertions(+), 39 deletions(-) diff --git a/src/config.rs b/src/config.rs index 0bf9ab0b..349aed58 100644 --- a/src/config.rs +++ b/src/config.rs @@ -871,26 +871,15 @@ pub struct Plugins { pub prewarmer: Option, } -pub trait Plugin { - fn is_enabled(&self) -> bool; -} - impl std::fmt::Display for Plugins { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - fn is_enabled(arg: Option<&T>) -> bool { - if arg.is_some() { - arg.unwrap().is_enabled() - } else { - false - } - } write!( f, "interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}", - is_enabled(self.intercept.as_ref()), - is_enabled(self.table_access.as_ref()), - is_enabled(self.query_logger.as_ref()), - is_enabled(self.prewarmer.as_ref()), + self.intercept.is_some(), + self.table_access.is_some(), + self.query_logger.is_some(), + self.prewarmer.is_some(), ) } } @@ -901,47 +890,23 @@ pub struct Intercept { pub queries: BTreeMap, } -impl Plugin for Intercept { - fn is_enabled(&self) -> bool { - self.enabled - } -} - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct TableAccess { pub enabled: bool, pub tables: Vec, } -impl Plugin for TableAccess { - fn is_enabled(&self) -> bool { - self.enabled - } -} - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct QueryLogger { pub enabled: bool, } -impl Plugin for QueryLogger { - fn is_enabled(&self) 
-> bool { - self.enabled - } -} - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct Prewarmer { pub enabled: bool, pub queries: Vec, } -impl Plugin for Prewarmer { - fn is_enabled(&self) -> bool { - self.enabled - } -} - impl Intercept { pub fn substitute(&mut self, db: &str, user: &str) { for (_, query) in self.queries.iter_mut() { From cd9889155232b8432a5e6b6a57d5ced56aef4a7a Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 22 Sep 2023 02:56:40 -0700 Subject: [PATCH 04/15] Reverted `Docker` file changes. --- Dockerfile | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index ec29182a..f2d58062 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,13 @@ -FROM rust:1 AS builder +FROM rust:1-slim-bookworm AS builder + +RUN apt-get update && \ + apt-get install -y build-essential + COPY . /app WORKDIR /app RUN cargo build --release -FROM debian:bullseye-slim +FROM debian:bookworm-slim COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml WORKDIR /etc/pgcat From 8e5552d4816c1bd2d24f1bef0dc2c8e3c58e9403 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 22 Sep 2023 02:59:33 -0700 Subject: [PATCH 05/15] Reverted the change in the default `prewarmer`. --- pgcat.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgcat.toml b/pgcat.toml index ebf7a4f7..579b9bec 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -227,7 +227,7 @@ connect_timeout = 3000 [pool.sharded_db.plugins] [pools.sharded_db.plugins.prewarmer] -enabled = false +enabled = true queries = [ "SELECT pg_prewarm('pgbench_accounts')", ] From d315dea665e723d48c61af1ec8474d6144424b79 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 22 Sep 2023 03:06:44 -0700 Subject: [PATCH 06/15] Removed prepared transaction timestamp from this PR. 
--- src/client_xact.rs | 42 +++--------------------------------------- src/server_xact.rs | 30 +----------------------------- 2 files changed, 4 insertions(+), 68 deletions(-) diff --git a/src/client_xact.rs b/src/client_xact.rs index d64c8a0a..427c65c9 100644 --- a/src/client_xact.rs +++ b/src/client_xact.rs @@ -2,8 +2,6 @@ use crate::client::Client; use crate::errors::Error; use crate::query_messages::{ErrorInfo, ErrorResponse, Message}; use bytes::BytesMut; -/// Handle clients by pretending to be a PostgreSQL server. -use chrono::NaiveDateTime; use futures::future::join_all; use itertools::Either; use log::{debug, warn}; @@ -16,32 +14,6 @@ use crate::messages::*; use crate::server::Server; use crate::server_xact::*; -/// DistributedPrepareResult is an accumulator for the results of the 'PREPARE TRANSACTION's. -#[derive(Debug, Default)] -struct DistributedPrepareResult { - max_prepare_timestamp: NaiveDateTime, -} - -impl DistributedPrepareResult { - /// Returns the maximum prepare timestamp of all servers. - pub fn get_max_prepare_timestamp(&self) -> NaiveDateTime { - self.max_prepare_timestamp - } - - /// Accumulates the results of a 'PREPARE TRANSACTION'. - /// The result is true if the server has a 'PREPARE TRANSACTION' timestamp. - pub fn accumulate(&mut self, server: &Server) -> bool { - let prep_timestamp = server.transaction_metadata().get_prepared_timestamp(); - if prep_timestamp.is_none() { - false - } else { - self.max_prepare_timestamp = - std::cmp::max(self.max_prepare_timestamp, prep_timestamp.unwrap()); - true - } - } -} - /// This function starts a distributed transaction by sending a BEGIN statement to the first server. /// It is called on the first server, as soon as client wants to interact with another server, /// which hints that the client wants to start a distributed transaction. 
@@ -383,11 +355,10 @@ where if res.is_right() { return Ok(res.right()); } - let res = res.left().unwrap(); let commit_prepared_results = join_all(all_conns.iter_mut().map(|(_, conn)| { let server = &mut *conn.0; - local_server_commit_prepared(server, res.get_max_prepare_timestamp()) + local_server_commit_prepared(server) })) .await; @@ -455,7 +426,7 @@ async fn distributed_prepare( (usize, Option), (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), >, -) -> Result, Error> +) -> Result, Error> where S: tokio::io::AsyncRead + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin, @@ -475,20 +446,13 @@ where }); // If there was any error, we need to abort the transaction. - let mut res = DistributedPrepareResult::default(); for prepare_res in prepare_results { if let Some(err) = prepare_res? { // For now, we just return the first error we encounter. return Ok(Either::Right(err)); } } - - // Otherwise, accumulate the results of 'PREPARE TRANSACTION'. - all_conns.iter_mut().for_each(|(_, conn)| { - let server = &mut *conn.0; - res.accumulate(&server); - }); - Ok(Either::Left(res)) + Ok(Either::Left(())) } /// This function is called when the client sends a query to the server without requiring an answer. diff --git a/src/server_xact.rs b/src/server_xact.rs index cb99643a..78e14fe5 100644 --- a/src/server_xact.rs +++ b/src/server_xact.rs @@ -1,7 +1,6 @@ /// Implementation of the PostgreSQL server (database) protocol. /// Here we are pretending to the a Postgres client. 
use bytes::Buf; -use chrono::NaiveDateTime; use itertools::Either; use log::{debug, warn}; use once_cell::sync::Lazy; @@ -216,7 +215,6 @@ pub struct TransactionMetaData { xact_gid: Option, snapshot: Option, - prepared_timestamp: Option, begin_statement: Option, commit_statement: Option, @@ -285,18 +283,6 @@ impl TransactionMetaData { self.snapshot.clone() } - pub fn set_prepared_timestamp(&mut self, prepared_timestamp: Option) { - self.prepared_timestamp = prepared_timestamp; - } - - pub fn get_prepared_timestamp(&self) -> Option { - self.prepared_timestamp - } - - pub fn has_done_prepare_transaction(&self) -> bool { - self.prepared_timestamp.is_some() - } - pub fn set_begin_statement(&mut self, begin_statement: Option) { self.begin_statement = begin_statement; } @@ -328,7 +314,6 @@ impl Default for TransactionMetaData { state: TransactionState::Idle, xact_gid: None, snapshot: None, - prepared_timestamp: None, begin_statement: None, commit_statement: None, abort_statement: None, @@ -495,15 +480,6 @@ pub async fn local_server_prepare_transaction( } let xact_gid = xact_gid.unwrap(); - if let Some(prep_time) = server.transaction_metadata.get_prepared_timestamp() { - return Err(Error::BadQuery(format!( - "The server ({}) was prepared in the past: {} (with gid: {})", - server.address(), - prep_time, - xact_gid, - ))); - } - let qres = server .query(&format!("PREPARE TRANSACTION '{}'", xact_gid)) .await?; @@ -516,12 +492,8 @@ pub async fn local_server_prepare_transaction( pub async fn local_server_commit_prepared( server: &mut Server, - commit_ts: NaiveDateTime, ) -> Result, Error> { - debug!( - "Called local_server_commit_prepared on {} with commit_ts: {:?}", - server.address, commit_ts - ); + debug!("Called local_server_commit_prepared on {}.", server.address); let xact_gid = server.transaction_metadata.get_xact_gid(); if xact_gid.is_none() { From e61615f619d5a06a46360fa919ff61cd84c52dac Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 22 Sep 2023 03:50:25 
-0700 Subject: [PATCH 07/15] Follow the expanded `Result` style used in the code-base. --- src/query_messages.rs | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/query_messages.rs b/src/query_messages.rs index 3e0d9bae..13e6d1f2 100644 --- a/src/query_messages.rs +++ b/src/query_messages.rs @@ -24,8 +24,6 @@ pub struct FieldDescription { format_code: i16, } -pub type PgWireResult = Result; - /// Get null-terminated string, returns None when empty cstring read. /// /// Note that this implementation will also advance cursor by 1 after reading @@ -73,9 +71,9 @@ pub(crate) fn decode_packet( buf: &mut BytesMut, offset: usize, decode_fn: F, -) -> PgWireResult> +) -> Result, Error> where - F: Fn(&mut BytesMut, usize) -> PgWireResult, + F: Fn(&mut BytesMut, usize) -> Result, { if let Some(msg_len) = get_length(buf, offset) { if buf.remaining() >= msg_len + offset { @@ -100,16 +98,16 @@ pub trait Message: Sized { fn message_length(&self) -> usize; /// Encode body part of the message. - fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()>; + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error>; /// Decode body part of the message. - fn decode_body(buf: &mut BytesMut, full_len: usize) -> PgWireResult; + fn decode_body(buf: &mut BytesMut, full_len: usize) -> Result; /// Default implementation for encoding message. /// /// Message type and length are encoded in this implementation and it calls /// `encode_body` for remaining parts. - fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> { + fn encode(&self, buf: &mut BytesMut) -> Result<(), Error> { if let Some(mt) = Self::message_type() { buf.put_u8(mt); } @@ -123,7 +121,7 @@ pub trait Message: Sized { /// Message type and length are decoded in this implementation and it calls /// `decode_body` for remaining parts. Return `None` if the packet is not /// complete for parsing. 
- fn decode(buf: &mut BytesMut) -> PgWireResult> { + fn decode(buf: &mut BytesMut) -> Result, Error> { let offset = Self::message_type().is_some().into(); decode_packet(buf, offset, |buf, full_len| { @@ -159,7 +157,7 @@ impl Message for RowDescription { .sum::() } - fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error> { buf.put_i16(self.fields.len() as i16); for field in &self.fields { @@ -175,7 +173,7 @@ impl Message for RowDescription { Ok(()) } - fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { + fn decode_body(buf: &mut BytesMut, _: usize) -> Result { let fields_len = buf.get_i16(); let mut fields = Vec::with_capacity(fields_len as usize); @@ -229,7 +227,7 @@ impl Message for DataRow { .sum::() } - fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error> { buf.put_i16(self.fields.len() as i16); for field in &self.fields { if let Some(bytes) = field { @@ -243,7 +241,7 @@ impl Message for DataRow { Ok(()) } - fn decode_body(buf: &mut BytesMut, _msg_len: usize) -> PgWireResult { + fn decode_body(buf: &mut BytesMut, _msg_len: usize) -> Result { let field_count = buf.get_i16() as usize; let mut fields = Vec::with_capacity(field_count); @@ -429,7 +427,7 @@ impl Message for ErrorResponse { + 1 } - fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> { + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error> { for (code, value) in &self.fields { buf.put_u8(*code); put_cstring(buf, value); @@ -440,7 +438,7 @@ impl Message for ErrorResponse { Ok(()) } - fn decode_body(buf: &mut BytesMut, _: usize) -> PgWireResult { + fn decode_body(buf: &mut BytesMut, _: usize) -> Result { let mut fields = Vec::new(); loop { let code = buf.get_u8(); From bfe83b1a3f7fc7a391cd56e7c802bba00d3e3291 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Tue, 26 Sep 2023 16:12:47 +0000 Subject: [PATCH 08/15] Fixed Query 
handling inside transaction loop. --- src/client.rs | 72 ++++++++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/src/client.rs b/src/client.rs index 16ee067e..a2e11a60 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1416,44 +1416,46 @@ where 'Q' => { let mut ast = None; if is_distributed_xact || query_router.query_parser_enabled() { - ast = parse_ast( - &mut initial_parsed_ast, - &query_router, - &message, - &client_identifier, - ); - - ast = if let Some(ast) = ast { - let plugin_result = query_router.execute_plugins(&ast).await; - - match plugin_result { - Ok(PluginOutput::Deny(error)) => { - error_response_with_state( - &mut self.write, - &error, - self.xact_info.state(), - ) - .await?; - continue; + ast = match query_router.parse(&message) { + Ok(ast) => { + let plugin_result = query_router.execute_plugins(&ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response_with_state( + &mut self.write, + &error, + self.xact_info.state(), + ) + .await?; + continue; + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + continue; + } + + _ => (), + }; + + let _ = query_router.infer(&ast); + + if is_distributed_xact { + if set_commit_or_abort_statement(self, &ast) { + break; + } } - Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - continue; - } - - _ => (), - }; - - if is_distributed_xact { - if set_commit_or_abort_statement(self, &ast) { - break; - } + Some(ast) + } + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + None } - - Some(ast) - } else { - None } } debug!("Sending query to server (in Query mode)"); From d7d91b6c1aee6b9c9014000b3122a536d8e08d82 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Sun, 1 Oct 2023 23:17:14 -0700 Subject: [PATCH 09/15] Improved the code in `client.rs` to minimize conflicts with upstream. 
--- src/client.rs | 201 +++++++++++++++++++++++++------------------------- 1 file changed, 99 insertions(+), 102 deletions(-) diff --git a/src/client.rs b/src/client.rs index a2e11a60..18fafced 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1012,6 +1012,8 @@ where query_router.update_pool_settings(pool.settings.clone()); + let mut is_initial_message = true; + let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { 0 => tokio::time::Duration::MAX, timeout => tokio::time::Duration::from_millis(timeout), @@ -1021,11 +1023,12 @@ where (usize, Option), (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), > = HashMap::new(); - let mut initial_message = Some(message); // Reset transaction state, as we are entering a new transaction loop. reset_client_xact(self); + // if we want to jump out of the transaction loop, but still want to `continue` in the + // custom protocol loop, we set this flag to true. let mut should_continue_in_outer_loop = false; // Transaction loop. Multiple queries can be issued by the client here. @@ -1035,66 +1038,6 @@ where // If the client is in session mode, no more custom protocol // commands will be accepted. loop { - let mut message = match initial_message { - None => { - trace!("Waiting for message inside transaction or in session mode"); - - // This is not an initial message so discard the initial_parsed_ast - initial_parsed_ast.take(); - - match tokio::time::timeout( - idle_client_timeout_duration, - read_message(&mut self.read), - ) - .await - { - Ok(Ok(message)) => message, - Ok(Err(err)) => { - // Client disconnected inside a transaction. - // Clean up the server and re-use it. 
- self.stats.disconnect(); - for (_, mut conn) in all_conns { - let server = &mut *conn.0; - server.checkin_cleanup().await?; - } - - return Err(err); - } - Err(_) => { - // Client idle in transaction timeout - error_response_with_state( - &mut self.write, - "idle transaction timeout", - self.xact_info.state(), - ) - .await?; - error!( - "Client idle in transaction timeout: \ - {{ \ - pool_name: {}, \ - username: {}, \ - shard: {:?}, \ - role: \"{:?}\" \ - }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role() - ); - - break; - } - } - } - - Some(message) => { - initial_message = None; - message - } - }; - - assign_client_transaction_state(self, &all_conns); - if all_conns.is_empty() || self.is_transparent_mode() { let current_shard = query_router.shard(); @@ -1322,8 +1265,8 @@ where } } let conn = conn_opt.unwrap(); - let server = &mut *conn.0; let address = &conn.1; + let server = &mut *conn.0; // Server is assigned to the client in case the client wants to // cancel a query later. @@ -1346,11 +1289,14 @@ where server.sync_parameters(&self.server_parameters).await?; } + assign_client_transaction_state(self, &all_conns); + let is_distributed_xact = all_conns.len() > 1; let server_key = (query_router.shard().unwrap_or(0), query_router.role()); let conn = all_conns.get_mut(&server_key).unwrap(); let server = &mut *conn.0; let address = &conn.1; + // Only check if we should rewrite prepared statements // in session mode. In transaction mode, we check at the beginning of // each transaction. 
@@ -1402,6 +1348,59 @@ where prepared_statement = None; } + if !is_initial_message { + is_initial_message = false; + } else { + trace!("Waiting for message inside transaction or in session mode"); + + // This is not an initial message so discard the initial_parsed_ast + initial_parsed_ast.take(); + + match tokio::time::timeout( + idle_client_timeout_duration, + read_message(&mut self.read), + ) + .await + { + Ok(Ok(msg)) => message = msg, + Ok(Err(err)) => { + // Client disconnected inside a transaction. + // Clean up the server and re-use it. + self.stats.disconnect(); + for (_, mut conn) in all_conns { + let server = &mut *conn.0; + server.checkin_cleanup().await?; + } + + return Err(err); + } + Err(_) => { + // Client idle in transaction timeout + error_response_with_state( + &mut self.write, + "idle transaction timeout", + self.xact_info.state(), + ) + .await?; + error!( + "Client idle in transaction timeout: \ + {{ \ + pool_name: {}, \ + username: {}, \ + shard: {:?}, \ + role: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role() + ); + + break; + } + }; + }; + // The message will be forwarded to the server intact. We still would like to // parse it below to figure out what to do with it. 
@@ -1416,45 +1415,40 @@ where 'Q' => { let mut ast = None; if is_distributed_xact || query_router.query_parser_enabled() { - ast = match query_router.parse(&message) { - Ok(ast) => { - let plugin_result = query_router.execute_plugins(&ast).await; - - match plugin_result { - Ok(PluginOutput::Deny(error)) => { - error_response_with_state( - &mut self.write, - &error, - self.xact_info.state(), - ) - .await?; - continue; - } - - Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - continue; - } - - _ => (), - }; - - let _ = query_router.infer(&ast); - - if is_distributed_xact { - if set_commit_or_abort_statement(self, &ast) { - break; - } + // We don't want to parse again if we already parsed it as the initial message + ast = parse_ast( + &mut initial_parsed_ast, + &mut query_router, + &message, + &client_identifier, + ); + + if let Some(ast) = &ast { + let plugin_result = query_router.execute_plugins(ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response_with_state( + &mut self.write, + &error, + self.xact_info.state(), + ) + .await?; + continue; } - Some(ast) - } - Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); - None + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + continue; + } + + _ => (), + }; + + if is_distributed_xact { + if set_commit_or_abort_statement(self, &ast) { + break; + } } } } @@ -1481,7 +1475,7 @@ where // We don't want to parse again if we already parsed it as the initial message ast = parse_ast( &mut initial_parsed_ast, - &query_router, + &mut query_router, &message, &client_identifier, ); @@ -2018,7 +2012,7 @@ where fn parse_ast( initial_parsed_ast: &mut Option>, - query_router: &QueryRouter, + query_router: &mut QueryRouter, message: &BytesMut, client_identifier: &ClientIdentifier, ) -> Option> { @@ -2026,7 +2020,10 @@ fn parse_ast( match *initial_parsed_ast { Some(_) => 
Some(initial_parsed_ast.take().unwrap()), None => match query_router.parse(message) { - Ok(ast) => Some(ast), + Ok(ast) => { + let _ = query_router.infer(&ast); + Some(ast) + } Err(error) => { warn!( "Query parsing error: {} (client: {})", From cc29d6d37c56e107f6ddb3312a43015dd8b03d7d Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Mon, 2 Oct 2023 23:55:41 -0700 Subject: [PATCH 10/15] Fixed some remaining auto-sharding issues. --- src/client.rs | 326 ++++++++++++++++++++++++++++----------------- src/client_xact.rs | 54 ++++---- src/server_xact.rs | 4 + 3 files changed, 240 insertions(+), 144 deletions(-) diff --git a/src/client.rs b/src/client.rs index 18fafced..07d36cb6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -17,7 +17,6 @@ use crate::auth_passthrough::refetch_auth_hash; use crate::client_xact::*; use crate::config::{ get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, - Role, }; use crate::constants::*; use crate::messages::*; @@ -91,7 +90,7 @@ pub struct Client { last_server_stats: Option>, /// Last server key we talked to. - pub(crate) last_server_key: Option<(usize, Option)>, + pub(crate) last_server_key: Option, /// Connected to server connected_to_server: bool, @@ -1012,7 +1011,7 @@ where query_router.update_pool_settings(pool.settings.clone()); - let mut is_initial_message = true; + let mut initial_message = Some(message); let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { 0 => tokio::time::Duration::MAX, @@ -1020,7 +1019,7 @@ where }; let mut all_conns: HashMap< - (usize, Option), + ServerId, (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), > = HashMap::new(); @@ -1038,6 +1037,116 @@ where // If the client is in session mode, no more custom protocol // commands will be accepted. 
loop { + let is_first_message_to_server = all_conns.is_empty(); + + let mut message = match initial_message { + None => { + trace!("Waiting for message inside transaction or in session mode"); + + // This is not an initial message so discard the initial_parsed_ast + initial_parsed_ast.take(); + + match tokio::time::timeout( + idle_client_timeout_duration, + read_message(&mut self.read), + ) + .await + { + Ok(Ok(message)) => message, + Ok(Err(err)) => { + // Client disconnected inside a transaction. + // Clean up the server and re-use it. + self.stats.disconnect(); + for (_, mut conn) in all_conns { + let server = &mut *conn.0; + server.checkin_cleanup().await?; + } + + return Err(err); + } + Err(_) => { + // Client idle in transaction timeout + error_response_with_state( + &mut self.write, + "idle transaction timeout", + self.xact_info.state(), + ) + .await?; + error!( + "Client idle in transaction timeout: \ + {{ \ + pool_name: {}, \ + username: {}, \ + shard: {:?}, \ + role: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role() + ); + + break; + } + } + } + + Some(message) => { + initial_message = None; + message + } + }; + + // Safe to unwrap because we know this message has a certain length and has the code + // This reads the first byte without advancing the internal pointer and mutating the bytes + let code = *message.get(0).unwrap() as char; + let mut ast = None; + + match code { + // Query + 'Q' => { + // If the first message is a `BEGIN` statement, then we are starting a + // transaction. However, we might not still be on the right shard (as the + // shard might be inferred from the first query). So we parse the query and + // store the `BEGIN` statement. Upon receiving the next query (and possibly + // determining the shard), we will execute the `BEGIN` statement. 
+ if let Some(ast_vec) = initial_parsed_ast.as_ref() { + if Self::is_begin_statement(ast_vec) { + assert_eq!(ast_vec.len(), 1); + + initialize_xact_info(self, &ast_vec[0]); + + custom_protocol_response_ok_with_state( + &mut self.write, + "BEGIN", + self.xact_info.state(), + ) + .await?; + + continue; + } + } + + if query_router.query_parser_enabled() { + let should_continue; + (should_continue, ast) = self + .parse_ast_helper( + &mut query_router, + &mut initial_parsed_ast, + &message, + &client_identifier, + ) + .await?; + if !should_continue { + continue; + } + } + } + _ => (), + }; + + assign_client_transaction_state(self, &all_conns); + if all_conns.is_empty() || self.is_transparent_mode() { let current_shard = query_router.shard(); @@ -1171,12 +1280,13 @@ where self.stats.waiting(); } - let server_key = (query_router.shard().unwrap_or(0), query_router.role()); + let query_router_shard = query_router.shard(); + let server_key = query_router_shard.unwrap_or(0); let mut conn_opt = all_conns.get_mut(&server_key); if conn_opt.is_none() { // Grab a server from the pool. let connection = match pool - .get(query_router.shard(), query_router.role(), &self.stats) + .get(query_router_shard, query_router.role(), &self.stats) .await { Ok(conn) => { @@ -1237,7 +1347,7 @@ where all_conns.iter_mut().next().unwrap(); let first_server = &mut *first_conn.0; - if !acquire_gid_and_snapshot(self, first_server_key, first_server) + if !acquire_gid_and_snapshot(self, *first_server_key, first_server) .await? { break; @@ -1256,7 +1366,7 @@ where "Sending implicit BEGIN statement to server {} (in transparent mode with distributed transaction)", address ); - if !begin_distributed_xact(self, &server_key, server).await? { + if !begin_distributed_xact(self, server_key, server).await? 
{ break; } Some(conn) @@ -1289,10 +1399,8 @@ where server.sync_parameters(&self.server_parameters).await?; } - assign_client_transaction_state(self, &all_conns); - let is_distributed_xact = all_conns.len() > 1; - let server_key = (query_router.shard().unwrap_or(0), query_router.role()); + let server_key = query_router.shard().unwrap_or(0); let conn = all_conns.get_mut(&server_key).unwrap(); let server = &mut *conn.0; let address = &conn.1; @@ -1348,103 +1456,36 @@ where prepared_statement = None; } - if !is_initial_message { - is_initial_message = false; - } else { - trace!("Waiting for message inside transaction or in session mode"); - - // This is not an initial message so discard the initial_parsed_ast - initial_parsed_ast.take(); - - match tokio::time::timeout( - idle_client_timeout_duration, - read_message(&mut self.read), - ) - .await - { - Ok(Ok(msg)) => message = msg, - Ok(Err(err)) => { - // Client disconnected inside a transaction. - // Clean up the server and re-use it. - self.stats.disconnect(); - for (_, mut conn) in all_conns { - let server = &mut *conn.0; - server.checkin_cleanup().await?; - } - - return Err(err); - } - Err(_) => { - // Client idle in transaction timeout - error_response_with_state( - &mut self.write, - "idle transaction timeout", - self.xact_info.state(), - ) - .await?; - error!( - "Client idle in transaction timeout: \ - {{ \ - pool_name: {}, \ - username: {}, \ - shard: {:?}, \ - role: \"{:?}\" \ - }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role() - ); - - break; - } - }; - }; - // The message will be forwarded to the server intact. We still would like to // parse it below to figure out what to do with it. 
- // Safe to unwrap because we know this message has a certain length and has the code - // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.get(0).unwrap() as char; - trace!("client Message: {}", code); match code { // Query 'Q' => { - let mut ast = None; - if is_distributed_xact || query_router.query_parser_enabled() { - // We don't want to parse again if we already parsed it as the initial message - ast = parse_ast( - &mut initial_parsed_ast, - &mut query_router, - &message, - &client_identifier, - ); + if is_distributed_xact { + // if we are in a distributed transaction, we need to parse the query + // to figure out if it's a COMMIT or ABORT statement. + // If query parsing is disabled, we need to parse it here. Otherwise, + // it's already parsed. + if !query_router.query_parser_enabled() { + assert_eq!(ast, None); + let should_continue; + (should_continue, ast) = self + .parse_ast_helper( + &mut query_router, + &mut initial_parsed_ast, + &message, + &client_identifier, + ) + .await?; + if !should_continue { + continue; + } + } if let Some(ast) = &ast { - let plugin_result = query_router.execute_plugins(ast).await; - - match plugin_result { - Ok(PluginOutput::Deny(error)) => { - error_response_with_state( - &mut self.write, - &error, - self.xact_info.state(), - ) - .await?; - continue; - } - - Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - continue; - } - - _ => (), - }; - if is_distributed_xact { if set_commit_or_abort_statement(self, &ast) { break; @@ -1454,7 +1495,30 @@ where } debug!("Sending query to server (in Query mode)"); - let server_was_in_transaction = server.in_transaction(); + // If this is the first message that we're actually sending to the server, + // we need to send the 'BEGIN' statement first if it was issued before this. 
+ // This is the case when the client is in transparent mode and A 'BEGIN' + // statement was issued before the first query. + if is_first_message_to_server { + if let Some(begin_stmt) = self.xact_info.get_begin_statement() { + let begin_stmt = begin_stmt.clone(); + if let Some(err) = + query_server(self, server, &begin_stmt.to_string()).await? + { + error_response_stmt( + &mut self.write, + &err, + self.xact_info.state(), + ) + .await?; + break; + } + + initialize_xact_params(self, server, &begin_stmt); + + assert!(server.in_transaction()); + } + } self.send_and_receive_loop( code, @@ -1466,27 +1530,7 @@ where ) .await?; - if server.in_transaction() { - // If the server was not in transaction and now it is, we need to store the - // begin statement. The begin statement is used if/when contacting another - // server in the same transaction. - if !server_was_in_transaction { - if ast.is_none() { - // We don't want to parse again if we already parsed it as the initial message - ast = parse_ast( - &mut initial_parsed_ast, - &mut query_router, - &message, - &client_identifier, - ); - } - assert!(ast.is_some()); - let ast_vec = ast.unwrap(); - assert_eq!(ast_vec.len(), 1); - - initialize_xact_info(self, server, &ast_vec[0]); - } - } else { + if !server.in_transaction() { // Report transaction executed statistics. 
self.stats.transaction(); server @@ -1756,6 +1800,40 @@ where } } + async fn parse_ast_helper( + &mut self, + query_router: &mut QueryRouter, + initial_parsed_ast: &mut Option>, + message: &BytesMut, + client_identifier: &ClientIdentifier, + ) -> Result<(bool, Option>), Error> { + Ok({ + // We don't want to parse again if we already parsed it as the initial message + let ast = parse_ast(initial_parsed_ast, query_router, message, client_identifier); + + if let Some(ast_ref) = &ast { + let plugin_result = query_router.execute_plugins(ast_ref).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response_with_state(&mut self.write, &error, self.xact_info.state()) + .await?; + (false, None) + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + (false, None) + } + + _ => (true, ast), + } + } else { + (true, None) + } + }) + } + /// Retrieve connection pool, if it exists. /// Return an error to the client otherwise. async fn get_pool(&mut self) -> Result { @@ -2008,6 +2086,14 @@ where pub fn is_in_idle_transaction(&self) -> bool { self.xact_info.is_idle() } + + fn is_begin_statement(ast_vec: &[sqlparser::ast::Statement]) -> bool { + ast_vec.len() == 1 + && matches!( + ast_vec[0], + sqlparser::ast::Statement::StartTransaction { .. } + ) + } } fn parse_ast( diff --git a/src/client_xact.rs b/src/client_xact.rs index 427c65c9..d48d9e07 100644 --- a/src/client_xact.rs +++ b/src/client_xact.rs @@ -9,17 +9,19 @@ use sqlparser::ast::{Statement, TransactionAccessMode, TransactionMode}; use std::collections::HashMap; use uuid::Uuid; -use crate::config::{Address, Role}; +use crate::config::Address; use crate::messages::*; use crate::server::Server; use crate::server_xact::*; +pub type ServerId = usize; + /// This function starts a distributed transaction by sending a BEGIN statement to the first server. 
/// It is called on the first server, as soon as client wants to interact with another server, /// which hints that the client wants to start a distributed transaction. pub async fn begin_distributed_xact( clnt: &mut Client, - server_key: &(usize, Option), + server_key: ServerId, server: &mut Server, ) -> Result where @@ -74,7 +76,7 @@ where /// Also, if the transaction is repeatable read or higher, it acquires a snapshot from the server. pub async fn acquire_gid_and_snapshot( clnt: &mut Client, - server_key: &(usize, Option), + server_key: ServerId, server: &mut Server, ) -> Result where @@ -134,17 +136,14 @@ fn generate_xact_gid(clnt: &Client) -> String { /// Generates a server-specific GID for a transaction. We need this, because it's possible that /// multiple servers might actually be the same server (which commonly happens in testing). -fn gen_server_specific_gid(server_key: &(usize, Option), gid: &str) -> String { - format!("{}_{}", server_key.0, gid) +fn gen_server_specific_gid(server_key: ServerId, gid: &str) -> String { + format!("{}_{}", server_key, gid) } /// Assigns the transaction state based on the state of all servers. pub fn assign_client_transaction_state( clnt: &mut Client, - all_conns: &HashMap< - (usize, Option), - (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), - >, + all_conns: &HashMap, Address)>, ) { clnt.xact_info.set_state(if all_conns.is_empty() { // if there's no server, we're in idle mode. @@ -162,10 +161,7 @@ pub fn assign_client_transaction_state( /// Returns true if any server is in a failed transaction. fn is_any_server_in_failed_xact( - all_conns: &HashMap< - (usize, Option), - (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), - >, + all_conns: &HashMap, Address)>, ) -> bool { all_conns .iter() @@ -173,15 +169,26 @@ fn is_any_server_in_failed_xact( } /// This function initializes the transaction parameters based on the first BEGIN statement. 
-pub fn initialize_xact_info( +pub fn initialize_xact_info(clnt: &mut Client, begin_stmt: &Statement) { + if let Statement::StartTransaction { .. } = begin_stmt { + // This is the first BEGIN statement. We need to register it for later executions. + clnt.xact_info.set_begin_statement(Some(begin_stmt.clone())); + + clnt.xact_info.set_state(TransactionState::InTransaction); + } else { + // If we were not in a transaction and the first statement is + // not a BEGIN, then it's an irrecovable error. + assert!(false); + } +} + +/// This function initializes the transaction parameters based on the server's default. +pub fn initialize_xact_params( clnt: &mut Client, server: &mut Server, begin_stmt: &Statement, ) { if let Statement::StartTransaction { modes } = begin_stmt { - // This is the first BEGIN statement. We need - clnt.xact_info.set_begin_statement(Some(begin_stmt.clone())); - // Initialize transaction parameters using the server's default. clnt.xact_info.params = server_default_transaction_parameters(server); for mode in modes { @@ -207,8 +214,7 @@ pub fn initialize_xact_info( // Set the transaction parameters on the first server. server.transaction_metadata_mut().params = clnt.xact_info.params.clone(); } else { - // If we were not in a transaction and the first statement is - // not a BEGIN, then it's an irrecovable error. + // If it's not a BEGIN, then it's an irrecovable error. 
assert!(false); } } @@ -222,7 +228,7 @@ pub fn initialize_xact_info( pub async fn distributed_commit_or_abort( clnt: &mut Client, all_conns: &mut HashMap< - (usize, Option), + ServerId, (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), >, ) -> Result<(), Error> @@ -321,7 +327,7 @@ pub fn reset_client_xact(clnt: &mut Client) { async fn distributed_commit( clnt: &mut Client, all_conns: &mut HashMap< - (usize, Option), + ServerId, (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), >, ) -> Result, Error> @@ -387,7 +393,7 @@ fn set_post_query_state(clnt: &mut Client, server: &mut Server) { async fn distributed_abort( clnt: &mut Client, all_conns: &mut HashMap< - (usize, Option), + ServerId, (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), >, abort_stmt: &String, @@ -423,7 +429,7 @@ where async fn distributed_prepare( clnt: &mut Client, all_conns: &mut HashMap< - (usize, Option), + ServerId, (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), >, ) -> Result, Error> @@ -456,7 +462,7 @@ where } /// This function is called when the client sends a query to the server without requiring an answer. -async fn query_server( +pub async fn query_server( clnt: &mut Client, server: &mut Server, stmt: &str, diff --git a/src/server_xact.rs b/src/server_xact.rs index 78e14fe5..eec4a4e9 100644 --- a/src/server_xact.rs +++ b/src/server_xact.rs @@ -306,6 +306,10 @@ impl TransactionMetaData { pub fn get_abort_statement(&self) -> Option<&Statement> { self.abort_statement.as_ref() } + + pub fn is_transaction_started(&self) -> bool { + self.begin_statement.is_some() + } } impl Default for TransactionMetaData { From 300c9d83ce06038c3332a6660366681ebd83e5d0 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 6 Oct 2023 09:54:37 -0700 Subject: [PATCH 11/15] Fixed the issue with registering query execution time. 
--- src/server.rs | 6 +----- src/server_xact.rs | 13 ------------- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/src/server.rs b/src/server.rs index 9445a3bd..43d405eb 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,7 +10,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use std::mem; use std::net::IpAddr; use std::sync::Arc; -use std::time::{Instant, SystemTime}; +use std::time::SystemTime; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::net::TcpStream; use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; @@ -1281,8 +1281,6 @@ impl Server { self.send(&query).await?; - let query_start = Instant::now(); - loop { let mut response = self.recv(None).await?; @@ -1296,8 +1294,6 @@ impl Server { } } - query_time_stats(self, query_start); - Ok(None) } diff --git a/src/server_xact.rs b/src/server_xact.rs index eec4a4e9..618bc2dd 100644 --- a/src/server_xact.rs +++ b/src/server_xact.rs @@ -6,7 +6,6 @@ use log::{debug, warn}; use once_cell::sync::Lazy; use sqlparser::ast::{Statement, TransactionIsolationLevel}; use std::collections::HashMap; -use std::time::Instant; use crate::errors::Error; use crate::messages::*; @@ -344,15 +343,6 @@ impl RWConflict { } } -/// Given the query start time, this function registers the query execution time to the server -/// stats. -pub fn query_time_stats(server: &mut Server, query_start: Instant) { - server.stats().query( - Instant::now().duration_since(query_start).as_millis() as u64, - &server.server_parameters.get_application_name(), - ); -} - /// Execute an arbitrary query against the server. /// It will use the simple query protocol. /// Result will be returned, so this is useful for things like `SELECT`. @@ -369,7 +359,6 @@ async fn query_with_response( server.send(&query).await?; - let query_start = Instant::now(); // Read all data the server has to offer, which can be multiple messages // buffered in 8196 bytes chunks. 
let mut response = server.recv(None).await?; @@ -407,8 +396,6 @@ async fn query_with_response( _ => return Err(Error::ServerError), }; - query_time_stats(server, query_start); - Ok(Either::Left(query_response)) } From a4a554d14ac7fac1f9144482f7b934763a038f40 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Fri, 6 Oct 2023 10:14:49 -0700 Subject: [PATCH 12/15] Applied cargo clippy suggestions. --- src/client.rs | 114 +++++++++++++++++++----------------------- src/client_xact.rs | 67 +++++++++++-------------- src/query_messages.rs | 1 + src/server.rs | 41 +++++---------- src/server_xact.rs | 12 ++--- 5 files changed, 98 insertions(+), 137 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9c6cb865..5abefb4f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -138,7 +138,7 @@ pub async fn client_entrypoint( // Client requested a TLS connection. Ok((ClientConnectionType::Tls, _)) => { // TLS settings are configured, will setup TLS now. - if tls_certificate != None { + if tls_certificate.is_some() { debug!("Accepting TLS request"); let mut yes = BytesMut::new(); @@ -455,7 +455,7 @@ where None => "pgcat", }; - let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name); + let client_identifier = ClientIdentifier::new(application_name, username, pool_name); let admin = ["pgcat", "pgbouncer"] .iter() @@ -806,7 +806,7 @@ where let mut will_prepare = false; let client_identifier = ClientIdentifier::new( - &self.server_parameters.get_application_name(), + self.server_parameters.get_application_name(), &self.username, &self.pool_name, ); @@ -993,15 +993,11 @@ where } // Check on plugin results. 
- match plugin_output { - Some(PluginOutput::Deny(error)) => { - self.buffer.clear(); - error_response(&mut self.write, &error).await?; - plugin_output = None; - continue; - } - - _ => (), + if let Some(PluginOutput::Deny(error)) = plugin_output { + self.buffer.clear(); + error_response(&mut self.write, &error).await?; + plugin_output = None; + continue; }; // Check if the pool is paused and wait until it's resumed. @@ -1100,50 +1096,47 @@ where // Safe to unwrap because we know this message has a certain length and has the code // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.get(0).unwrap() as char; + let code = *message.first().unwrap() as char; let mut ast = None; - match code { - // Query - 'Q' => { - // If the first message is a `BEGIN` statement, then we are starting a - // transaction. However, we might not still be on the right shard (as the - // shard might be inferred from the first query). So we parse the query and - // store the `BEGIN` statement. Upon receiving the next query (and possibly - // determining the shard), we will execute the `BEGIN` statement. - if let Some(ast_vec) = initial_parsed_ast.as_ref() { - if Self::is_begin_statement(ast_vec) { - assert_eq!(ast_vec.len(), 1); - - initialize_xact_info(self, &ast_vec[0]); - - custom_protocol_response_ok_with_state( - &mut self.write, - "BEGIN", - self.xact_info.state(), - ) - .await?; + // Query + if code == 'Q' { + // If the first message is a `BEGIN` statement, then we are starting a + // transaction. However, we might not still be on the right shard (as the + // shard might be inferred from the first query). So we parse the query and + // store the `BEGIN` statement. Upon receiving the next query (and possibly + // determining the shard), we will execute the `BEGIN` statement. 
+ if let Some(ast_vec) = initial_parsed_ast.as_ref() { + if Self::is_begin_statement(ast_vec) { + assert_eq!(ast_vec.len(), 1); - continue; - } + initialize_xact_info(self, &ast_vec[0]); + + custom_protocol_response_ok_with_state( + &mut self.write, + "BEGIN", + self.xact_info.state(), + ) + .await?; + + continue; } + } - if query_router.query_parser_enabled() { - let should_continue; - (should_continue, ast) = self - .parse_ast_helper( - &mut query_router, - &mut initial_parsed_ast, - &message, - &client_identifier, - ) - .await?; - if !should_continue { - continue; - } + if query_router.query_parser_enabled() { + let should_continue; + (should_continue, ast) = self + .parse_ast_helper( + &mut query_router, + &mut initial_parsed_ast, + &message, + &client_identifier, + ) + .await?; + if !should_continue { + continue; } } - _ => (), }; assign_client_transaction_state(self, &all_conns); @@ -1487,10 +1480,8 @@ where } if let Some(ast) = &ast { - if is_distributed_xact { - if set_commit_or_abort_statement(self, &ast) { - break; - } + if is_distributed_xact && set_commit_or_abort_statement(self, ast) { + break; } } } @@ -1536,7 +1527,7 @@ where self.stats.transaction(); server .stats() - .transaction(&self.server_parameters.get_application_name()); + .transaction(self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction or transparent modes. // If we are in session mode, we keep the server until the client disconnects. @@ -1613,13 +1604,10 @@ where let close: Close = (&message).try_into()?; if close.is_prepared_statement() && !close.anonymous() { - match self.prepared_statements.get(&close.name) { - Some(parse) => { - server.will_close(&parse.generated_name); - } - + if let Some(parse) = self.prepared_statements.get(&close.name) { + server.will_close(&parse.generated_name); + } else { // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. 
- None => (), }; } } @@ -1663,7 +1651,7 @@ where self.buffer.put(&message[..]); - let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' && !prepared_statements_enabled { @@ -1695,7 +1683,7 @@ where self.stats.transaction(); server .stats() - .transaction(&self.server_parameters.get_application_name()); + .transaction(self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction or transparent modes. // If we are in session mode, we keep the server until the client disconnects. @@ -2011,7 +1999,7 @@ where client_stats.query(); server.stats().query( Instant::now().duration_since(query_start).as_millis() as u64, - &self.server_parameters.get_application_name(), + self.server_parameters.get_application_name(), ); Ok(()) diff --git a/src/client_xact.rs b/src/client_xact.rs index d48d9e07..8a847a70 100644 --- a/src/client_xact.rs +++ b/src/client_xact.rs @@ -2,6 +2,7 @@ use crate::client::Client; use crate::errors::Error; use crate::query_messages::{ErrorInfo, ErrorResponse, Message}; use bytes::BytesMut; +use core::panic; use futures::future::join_all; use itertools::Either; use log::{debug, warn}; @@ -127,11 +128,7 @@ where /// Generates a random GID (i.e., Global transaction ID) for a transaction. fn generate_xact_gid(clnt: &Client) -> String { - format!( - "txn_{}_{}", - clnt.addr.to_string(), - Uuid::new_v4().to_string() - ) + format!("txn_{}_{}", clnt.addr, Uuid::new_v4()) } /// Generates a server-specific GID for a transaction. We need this, because it's possible that @@ -148,14 +145,12 @@ pub fn assign_client_transaction_state( clnt.xact_info.set_state(if all_conns.is_empty() { // if there's no server, we're in idle mode. TransactionState::Idle + } else if is_any_server_in_failed_xact(all_conns) { + // if any server is in failed transaction, we're in failed transaction. 
+ TransactionState::InFailedTransaction } else { - if is_any_server_in_failed_xact(all_conns) { - // if any server is in failed transaction, we're in failed transaction. - TransactionState::InFailedTransaction - } else { - // if we have at least one server and it is in a transaction, we're in a transaction. - TransactionState::InTransaction - } + // if we have at least one server and it is in a transaction, we're in a transaction. + TransactionState::InTransaction }); } @@ -165,7 +160,7 @@ fn is_any_server_in_failed_xact( ) -> bool { all_conns .iter() - .any(|(_, conn)| in_failed_transaction(&*conn.0)) + .any(|(_, conn)| in_failed_transaction(&conn.0)) } /// This function initializes the transaction parameters based on the first BEGIN statement. @@ -177,8 +172,8 @@ pub fn initialize_xact_info(clnt: &mut Client, begin_stmt: &Statemen clnt.xact_info.set_state(TransactionState::InTransaction); } else { // If we were not in a transaction and the first statement is - // not a BEGIN, then it's an irrecovable error. - assert!(false); + // not a BEGIN, then it's an irrecoverable error. + panic!("The first statement in a transaction is not a BEGIN statement."); } } @@ -200,9 +195,7 @@ pub fn initialize_xact_params( }); } TransactionMode::IsolationLevel(isolation_level) => { - clnt.xact_info - .params - .set_isolation_level(isolation_level.clone()); + clnt.xact_info.params.set_isolation_level(*isolation_level); } } } @@ -214,8 +207,8 @@ pub fn initialize_xact_params( // Set the transaction parameters on the first server. server.transaction_metadata_mut().params = clnt.xact_info.params.clone(); } else { - // If it's not a BEGIN, then it's an irrecovable error. - assert!(false); + // If it's not a BEGIN, then it's an irrecoverable error. 
+ panic!("The statement is not a BEGIN statement."); } } @@ -238,9 +231,9 @@ where { let dist_commit = clnt.xact_info.get_commit_statement(); let dist_abort = clnt.xact_info.get_abort_statement(); - Ok(if dist_commit.is_some() || dist_abort.is_some() { + if dist_commit.is_some() || dist_abort.is_some() { // if either a commit or abort statement is set, we should be in a distributed transaction. - assert!(all_conns.len() > 0); + assert!(!all_conns.is_empty()); let is_chained = should_be_chained(dist_commit, dist_abort); let dist_commit = dist_commit.map(|stmt| stmt.to_string()); @@ -316,7 +309,8 @@ where last_server.address() ); } - }) + } + Ok(()) } pub fn reset_client_xact(clnt: &mut Client) { @@ -396,7 +390,7 @@ async fn distributed_abort( ServerId, (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), >, - abort_stmt: &String, + abort_stmt: &str, ) -> Result, Error> where S: tokio::io::AsyncRead + std::marker::Unpin, @@ -414,7 +408,7 @@ where set_post_query_state(clnt, server); server .stats() - .transaction(&clnt.server_parameters.get_application_name()); + .transaction(clnt.server_parameters.get_application_name()); }); for abort_res in abort_results { @@ -426,6 +420,7 @@ where Ok(None) } +#[allow(clippy::type_complexity)] async fn distributed_prepare( clnt: &mut Client, all_conns: &mut HashMap< @@ -479,10 +474,10 @@ where /// Returns true if the statement is a commit or abort statement. Also, it sets the commit or abort /// statement on the client. 
pub fn set_commit_or_abort_statement(clnt: &mut Client, ast: &Vec) -> bool { - if is_commit_statement(&ast) { + if is_commit_statement(ast) { clnt.xact_info.set_commit_statement(Some(ast[0].clone())); true - } else if is_abort_statement(&ast) { + } else if is_abort_statement(ast) { clnt.xact_info.set_abort_statement(Some(ast[0].clone())); true } else { @@ -493,12 +488,9 @@ pub fn set_commit_or_abort_statement(clnt: &mut Client, ast: &Vec) -> bool { for statement in ast { - match *statement { - Statement::Commit { .. } => { - assert_eq!(ast.len(), 1); - return true; - } - _ => (), + if let Statement::Commit { .. } = *statement { + assert_eq!(ast.len(), 1); + return true; } } false @@ -507,12 +499,9 @@ fn is_commit_statement(ast: &Vec) -> bool { /// Returns true if the statement is an abort statement. fn is_abort_statement(ast: &Vec) -> bool { for statement in ast { - match *statement { - Statement::Rollback { .. } => { - assert_eq!(ast.len(), 1); - return true; - } - _ => (), + if let Statement::Rollback { .. } = *statement { + assert_eq!(ast.len(), 1); + return true; } } false diff --git a/src/query_messages.rs b/src/query_messages.rs index 13e6d1f2..f695c95d 100644 --- a/src/query_messages.rs +++ b/src/query_messages.rs @@ -321,6 +321,7 @@ pub struct ErrorInfo { } impl ErrorInfo { + #[allow(clippy::too_many_arguments)] pub fn new( severity: String, code: String, diff --git a/src/server.rs b/src/server.rs index 43d405eb..c5f3ab31 100644 --- a/src/server.rs +++ b/src/server.rs @@ -203,12 +203,8 @@ impl ServerParameters { key = "DateStyle".to_string(); }; - if TRACKED_PARAMETERS.contains(&key) { + if TRACKED_PARAMETERS.contains(&key) || startup { self.parameters.insert(key, value); - } else { - if startup { - self.parameters.insert(key, value); - } } } @@ -338,6 +334,7 @@ pub struct Server { impl Server { /// Pretend to be the Postgres client and connect to the server given host, port and credentials. 
/// Perform the authentication and return the server in a ready for query state. + #[allow(clippy::too_many_arguments)] pub async fn startup( address: &Address, user: &User, @@ -446,10 +443,7 @@ impl Server { // Something else? m => { - return Err(Error::SocketError(format!( - "Unknown message: {}", - m as char - ))); + return Err(Error::SocketError(format!("Unknown message: {}", { m }))); } } } else { @@ -467,26 +461,17 @@ impl Server { None => &user.username, }; - let password = match user.server_password { - Some(ref server_password) => Some(server_password), - None => match user.password { - Some(ref password) => Some(password), - None => None, - }, - }; + let password = user.server_password.as_ref().or(user.password.as_ref()); startup(&mut stream, username, database).await?; let mut process_id: i32 = 0; let mut secret_key: i32 = 0; - let server_identifier = ServerIdentifier::new(username, &database); + let server_identifier = ServerIdentifier::new(username, database); // We'll be handling multiple packets, but they will all be structured the same. // We'll loop here until this exchange is complete.
- let mut scram: Option = match password { - Some(password) => Some(ScramSha256::new(password)), - None => None, - }; + let mut scram: Option = password.map(|password| ScramSha256::new(password)); let mut server_parameters = ServerParameters::new(); @@ -891,7 +876,7 @@ impl Server { self.mirror_send(messages); self.stats().data_sent(messages.len()); - match write_all_flush(&mut self.stream, &messages).await { + match write_all_flush(&mut self.stream, messages).await { Ok(_) => { // Successfully sent to server self.last_activity = SystemTime::now(); @@ -1361,16 +1346,14 @@ impl Server { } pub fn mirror_send(&mut self, bytes: &BytesMut) { - match self.mirror_manager.as_mut() { - Some(manager) => manager.send(bytes), - None => (), + if let Some(manager) = self.mirror_manager.as_mut() { + manager.send(bytes) } } pub fn mirror_disconnect(&mut self) { - match self.mirror_manager.as_mut() { - Some(manager) => manager.disconnect(), - None => (), + if let Some(manager) = self.mirror_manager.as_mut() { + manager.disconnect() } } @@ -1399,7 +1382,7 @@ impl Server { server.send(&simple_query(query)).await?; let mut message = server.recv(None).await?; - Ok(parse_query_message(&mut message).await?) + parse_query_message(&mut message).await } pub fn transaction_metadata(&self) -> &TransactionMetaData { diff --git a/src/server_xact.rs b/src/server_xact.rs index 618bc2dd..19e27cfd 100644 --- a/src/server_xact.rs +++ b/src/server_xact.rs @@ -14,11 +14,11 @@ use crate::server::{Server, ServerParameters}; /// The default transaction parameters that might be configured on the server. 
pub static TRANSACTION_PARAMETERS: Lazy> = Lazy::new(|| { - let mut list = Vec::new(); - list.push("default_transaction_isolation".to_string()); - list.push("default_transaction_read_only".to_string()); - list.push("default_transaction_deferrable".to_string()); - list + vec![ + "default_transaction_isolation".to_string(), + "default_transaction_read_only".to_string(), + "default_transaction_deferrable".to_string(), + ] }); /// The default transaction parameters that are either configured on the server or set by the @@ -415,7 +415,7 @@ pub async fn acquire_xact_snapshot( assert!(qres.row_desc().fields().len() == 1); assert!(qres_rows.len() == 1); if let Some(snapshot) = qres_rows[0].fields().get(0).unwrap() { - let snapshot = std::str::from_utf8(&snapshot).unwrap().to_string(); + let snapshot = std::str::from_utf8(snapshot).unwrap().to_string(); debug!("Got snapshot: {}", snapshot); From 735fad31dfe9958023fc807bdae8079142d835ee Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Tue, 10 Oct 2023 23:57:21 -0700 Subject: [PATCH 13/15] Improved the code quality. --- src/client.rs | 1684 ++++++++++++++++++++++------------------- src/client_xact.rs | 835 ++++++++++---------- src/errors.rs | 4 + src/messages.rs | 18 +- src/query_messages.rs | 32 +- src/server.rs | 37 +- src/server_xact.rs | 481 ++++-------- 7 files changed, 1514 insertions(+), 1577 deletions(-) diff --git a/src/client.rs b/src/client.rs index ef70d163..8af3d386 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,12 +1,15 @@ use crate::errors::{ClientIdentifier, Error}; use crate::pool::BanReason; +use crate::server_xact::TransactionState; /// Handle clients by pretending to be a PostgreSQL server. 
use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; use once_cell::sync::Lazy; +use sqlparser::ast::Statement; use std::collections::HashMap; +use std::ops::ControlFlow; use std::sync::{atomic::AtomicUsize, Arc}; -use std::time::Instant; +use std::time::{Duration, Instant}; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio::sync::broadcast::Receiver; @@ -24,7 +27,6 @@ use crate::plugins::PluginOutput; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; use crate::server::{Server, ServerParameters}; -use crate::server_xact::*; use crate::stats::{ClientStats, ServerStats}; use crate::tls::Tls; @@ -110,7 +112,7 @@ pub struct Client { /// Prepared statements prepared_statements: HashMap, - pub(crate) xact_info: TransactionMetaData, + pub(crate) xact_info: ClientTxnMetaData, } /// Client entrypoint. @@ -143,7 +145,7 @@ pub async fn client_entrypoint( let mut yes = BytesMut::new(); yes.put_u8(b'S'); - write_all(&mut stream, yes).await?; + write_all(&mut stream, &yes).await?; // Negotiate TLS. match startup_tls(stream, client_server_map, shutdown, admin_only).await { @@ -178,7 +180,7 @@ pub async fn client_entrypoint( // Rejecting client request for TLS. let mut no = BytesMut::new(); no.put_u8(b'N'); - write_all(&mut stream, no).await?; + write_all(&mut stream, &no).await?; // Attempting regular startup. Client can disconnect now // if they choose. 
@@ -555,7 +557,7 @@ where error_response( &mut write, &format!( - "No pool configured for database: {:?}, user: {:?} (in startup)", + "No pool configured for database: {:?}, user: {:?}", pool_name, username ), ) @@ -687,7 +689,7 @@ where debug!("Password authentication successful"); auth_ok(&mut write).await?; - write_all(&mut write, (&server_parameters).into()).await?; + write_all(&mut write, &(&server_parameters).into()).await?; backend_key_data(&mut write, process_id, secret_key).await?; ready_for_query(&mut write).await?; @@ -763,33 +765,171 @@ where }) } - /// Handle a connected and authenticated client. - pub async fn handle(&mut self) -> Result<(), Error> { - // The client wants to cancel a query it has issued previously. - if self.cancel_mode { - trace!("Sending CancelRequest"); + async fn handle_cancel_mode(&mut self) -> Result<(), Error> { + trace!("Sending CancelRequest"); + let (process_id, secret_key, address, port) = { + let guard = self.client_server_map.lock(); + + match guard.get(&(self.process_id, self.secret_key)) { + // Drop the mutex as soon as possible. + // We found the server the client is using for its query + // that it wants to cancel. + Some((process_id, secret_key, address, port)) => { + (*process_id, *secret_key, address.clone(), *port) + } - let (process_id, secret_key, address, port) = { - let guard = self.client_server_map.lock(); + // The client doesn't know / got the wrong server, + // we're closing the connection for security reasons. + None => return Ok(()), + } + }; - match guard.get(&(self.process_id, self.secret_key)) { - // Drop the mutex as soon as possible. - // We found the server the client is using for its query - // that it wants to cancel. - Some((process_id, secret_key, address, port)) => { - (*process_id, *secret_key, address.clone(), *port) + // Opens a new separate connection to the server, sends the backend_id + // and secret_key and then closes it for security reasons. No other interactions + // take place. 
+ Server::cancel(&address, port, process_id, secret_key).await + } + + #[allow(clippy::too_many_arguments)] + async fn handle_message_in_custom_protocol_loop( + &mut self, + mut message: BytesMut, + client_identifier: &ClientIdentifier, + query_router: &mut QueryRouter, + will_prepare: &mut bool, + prepared_statements_enabled: &mut bool, + prepared_statement: &mut Option, + plugin_output: &mut Option, + ) -> Result>), ()>, Error> { + let mut initial_parsed_ast = None; + match message[0] as char { + // Buffer extended protocol messages even if we do not have + // a server connection yet. Hopefully, when we get the S message + // we'll be able to allocate a connection. Also, clients do not expect + // the server to respond to these messages so even if we were not able to + // allocate a connection, we wouldn't be able to send back an error message + // to the client so we buffer them and defer the decision to error out or not + // to when we get the S message + 'D' => { + if *prepared_statements_enabled { + let name; + (name, message) = self.rewrite_describe(message).await?; + + if let Some(name) = name { + *prepared_statement = Some(name); } + } + + self.buffer.put(&message[..]); + return Ok(ControlFlow::Continue(())); + } - // The client doesn't know / got the wrong server, - // we're closing the connection for security reasons. 
- None => return Ok(()), + 'E' => { + self.buffer.put(&message[..]); + return Ok(ControlFlow::Continue(())); + } + + 'Q' => { + if query_router.query_parser_enabled() { + match query_router.parse(&message) { + Ok(ast) => { + let plugin_result = query_router.execute_plugins(&ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + return Ok(ControlFlow::Continue(())); + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, &result).await?; + return Ok(ControlFlow::Continue(())); + } + + _ => (), + }; + + let _ = query_router.infer(&ast); + + initial_parsed_ast = Some(ast); + } + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } + } } - }; + } + + 'P' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_parse(message)?; + *will_prepare = true; + } + + self.buffer.put(&message[..]); - // Opens a new separate connection to the server, sends the backend_id - // and secret_key and then closes it for security reasons. No other interactions - // take place. 
- return Server::cancel(&address, port, process_id, secret_key).await; + if query_router.query_parser_enabled() { + match query_router.parse(&message) { + Ok(ast) => { + if let Ok(output) = query_router.execute_plugins(&ast).await { + *plugin_output = Some(output); + } + + let _ = query_router.infer(&ast); + } + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } + }; + } + + return Ok(ControlFlow::Continue(())); + } + + 'B' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_bind(message).await?; + } + + self.buffer.put(&message[..]); + + if query_router.query_parser_enabled() { + query_router.infer_shard_from_bind(&message); + } + + return Ok(ControlFlow::Continue(())); + } + + // Close (F) + 'C' => { + if *prepared_statements_enabled { + let close: Close = (&message).try_into()?; + + if close.is_prepared_statement() && !close.anonymous() { + self.prepared_statements.remove(&close.name); + write_all_flush(&mut self.write, &close_complete()).await?; + return Ok(ControlFlow::Continue(())); + } + } + } + + _ => (), + } + + Ok(ControlFlow::Break((message, initial_parsed_ast))) + } + + /// Handle a connected and authenticated client. + pub async fn handle(&mut self) -> Result<(), Error> { + // The client wants to cancel a query it has issued previously. + if self.cancel_mode { + return self.handle_cancel_mode().await; } // The query router determines where the query is going to go, @@ -870,127 +1010,26 @@ where let mut pool = self.get_pool().await?; query_router.update_pool_settings(pool.settings.clone()); - let mut initial_parsed_ast = None; - - match message[0] as char { - // Buffer extended protocol messages even if we do not have - // a server connection yet. Hopefully, when we get the S message - // we'll be able to allocate a connection. 
Also, clients do not expect - // the server to respond to these messages so even if we were not able to - // allocate a connection, we wouldn't be able to send back an error message - // to the client so we buffer them and defer the decision to error out or not - // to when we get the S message - 'D' => { - if prepared_statements_enabled { - let name; - (name, message) = self.rewrite_describe(message).await?; - - if let Some(name) = name { - prepared_statement = Some(name); - } - } - - self.buffer.put(&message[..]); - continue; - } - - 'E' => { - self.buffer.put(&message[..]); - continue; - } - - 'Q' => { - if query_router.query_parser_enabled() { - match query_router.parse(&message) { - Ok(ast) => { - let plugin_result = query_router.execute_plugins(&ast).await; - - match plugin_result { - Ok(PluginOutput::Deny(error)) => { - error_response(&mut self.write, &error).await?; - continue; - } - - Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - continue; - } - - _ => (), - }; - - let _ = query_router.infer(&ast); - - initial_parsed_ast = Some(ast); - } - Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); - } - } - } - } - - 'P' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_parse(message)?; - will_prepare = true; - } - - self.buffer.put(&message[..]); - - if query_router.query_parser_enabled() { - match query_router.parse(&message) { - Ok(ast) => { - if let Ok(output) = query_router.execute_plugins(&ast).await { - plugin_output = Some(output); - } - - let _ = query_router.infer(&ast); - } - Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); - } - }; - } - - continue; - } - - 'B' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_bind(message).await?; - } - - self.buffer.put(&message[..]); - - if query_router.query_parser_enabled() { - 
query_router.infer_shard_from_bind(&message); - } + let initial_parsed_ast; + + (message, initial_parsed_ast) = match self + .handle_message_in_custom_protocol_loop( + message, + &client_identifier, + &mut query_router, + &mut will_prepare, + &mut prepared_statements_enabled, + &mut prepared_statement, + &mut plugin_output, + ) + .await? + { + ControlFlow::Break(res) => res, + ControlFlow::Continue(()) => { continue; } - - // Close (F) - 'C' => { - if prepared_statements_enabled { - let close: Close = (&message).try_into()?; - - if close.is_prepared_statement() && !close.anonymous() { - self.prepared_statements.remove(&close.name); - write_all_flush(&mut self.write, &close_complete()).await?; - continue; - } - } - } - - _ => (), - } + }; // Check on plugin results. if let Some(PluginOutput::Deny(error)) = plugin_output { @@ -1008,157 +1047,288 @@ where query_router.update_pool_settings(pool.settings.clone()); - let mut initial_message = Some(message); - - let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { - 0 => tokio::time::Duration::MAX, - timeout => tokio::time::Duration::from_millis(timeout), - }; + // Reset transaction state, as we are entering a new transaction loop. + self.reset_client_xact(); let mut all_conns: HashMap< ServerId, (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), > = HashMap::new(); - // Reset transaction state, as we are entering a new transaction loop. - reset_client_xact(self); + match self + .handle_transaction_loop( + &client_identifier, + &mut query_router, + &pool, + &mut all_conns, + Some(message), + initial_parsed_ast, + &mut will_prepare, + &mut prepared_statements_enabled, + &mut prepared_statement, + &mut plugin_output, + ) + .await? + { + Some(ControlFlow::Break(())) => { + return Ok(()); + } - // if we want to jump out of the transaction loop, but still want to `continue` in the - // custom protocol loop, we set this flag to true. 
- let mut should_continue_in_outer_loop = false; + Some(ControlFlow::Continue(())) => { + continue; + } - // Transaction loop. Multiple queries can be issued by the client here. - // The connection belongs to the client until the transaction is over, - // or until the client disconnects if we are in session mode. - // - // If the client is in session mode, no more custom protocol - // commands will be accepted. - loop { - let is_first_message_to_server = all_conns.is_empty(); + _ => (), + } - let mut message = match initial_message { - None => { - trace!("Waiting for message inside transaction or in session mode"); + self.cleanup_custom_protocol_loop_helper(all_conns, prepared_statements_enabled) + .await?; + } + } - // This is not an initial message so discard the initial_parsed_ast - initial_parsed_ast.take(); + async fn cleanup_custom_protocol_loop_helper( + &mut self, + mut all_conns: HashMap< + usize, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + prepared_statements_enabled: bool, + ) -> Result<(), Error> { + self.distributed_commit_or_abort(&mut all_conns).await?; - match tokio::time::timeout( - idle_client_timeout_duration, - read_message(&mut self.read), - ) - .await - { - Ok(Ok(message)) => message, - Ok(Err(err)) => { - // Client disconnected inside a transaction. - // Clean up the server and re-use it. - self.stats.disconnect(); - for (_, mut conn) in all_conns { - let server = &mut *conn.0; - server.checkin_cleanup().await?; - } + // Reset transaction state for safety reasons. Even if this state will be reset before + // the next transaction, this dirty state could be seen in-between here and there. 
+ self.reset_client_xact(); - return Err(err); - } - Err(_) => { - // Client idle in transaction timeout - error_response_with_state( - &mut self.write, - "idle transaction timeout", - self.xact_info.state(), - ) - .await?; - error!( - "Client idle in transaction timeout: \ + debug!("Releasing servers back into the pool"); + for conn in all_conns.values_mut() { + let server = &mut *conn.0; + let address = &conn.1; + + // The server is no longer bound to us, we can't cancel its queries anymore. + debug!("Releasing server back into the pool: {}", address); + + server.checkin_cleanup().await?; + + if prepared_statements_enabled { + server.maintain_cache().await?; + } + + server.stats().idle(); + } + self.connected_to_server = false; + self.release(); + self.stats.idle(); + Ok(()) + } + async fn read_message_helper( + &mut self, + idle_client_timeout_duration: Duration, + query_router: &mut QueryRouter, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + initial_message: Option, + initial_parsed_ast: &mut Option>, + ) -> Result, Error> { + match initial_message { + None => { + trace!("Waiting for message inside transaction or in session mode"); + + // This is not an initial message so discard the initial_parsed_ast + initial_parsed_ast.take(); + + match tokio::time::timeout( + idle_client_timeout_duration, + read_message(&mut self.read), + ) + .await + { + Ok(Ok(message)) => Ok(ControlFlow::Continue(message)), + Ok(Err(err)) => { + // Client disconnected inside a transaction. + // Clean up the server and re-use it.
+ self.stats.disconnect(); + for conn in all_conns.values_mut() { + let server = &mut *conn.0; + server.checkin_cleanup().await?; + } + + Err(err) + } + Err(_) => { + // Client idle in transaction timeout + error_response_with_state( + &mut self.write, + "idle transaction timeout", + self.xact_info.state(), + ) + .await?; + error!( + "Client idle in transaction timeout: \ {{ \ pool_name: {}, \ username: {}, \ shard: {:?}, \ role: \"{:?}\" \ }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role() - ); + self.pool_name, + self.username, + query_router.shard(), + query_router.role() + ); - break; - } - } + Ok(ControlFlow::Break(())) } + } + } + + Some(message) => Ok(ControlFlow::Continue(message)), + } + } - Some(message) => { - initial_message = None; - message + async fn handle_begin_statement( + &mut self, + code: char, + message: &BytesMut, + client_identifier: &ClientIdentifier, + query_router: &mut QueryRouter, + initial_parsed_ast: &mut Option>, + ) -> Result>, ()>, Error> { + // Query + if code == 'Q' { + // If the first message is a `BEGIN` statement, then we are starting a + // transaction. However, we might not still be on the right shard (as the + // shard might be inferred from the first query). So we parse the query and + // store the `BEGIN` statement. Upon receiving the next query (and possibly + // determining the shard), we will execute the `BEGIN` statement. + if let Some(ast_vec) = initial_parsed_ast { + if Self::is_begin_statement(ast_vec) { + assert_eq!(ast_vec.len(), 1); + + let begin_stmt = &ast_vec[0]; + + if let Statement::StartTransaction { .. } = begin_stmt { + // This is the first BEGIN statement. We need to register it for later executions. 
+ self.xact_info.set_begin_statement(Some(begin_stmt.clone())); + + self.xact_info.set_state(TransactionState::InTransaction); + } else { + panic!("Expected BEGIN statement, got {:?}", begin_stmt); } - }; - // Safe to unwrap because we know this message has a certain length and has the code - // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.first().unwrap() as char; - let mut ast = None; + custom_protocol_response_ok_with_state( + &mut self.write, + "BEGIN", + self.xact_info.state(), + ) + .await?; - // Query - if code == 'Q' { - // If the first message is a `BEGIN` statement, then we are starting a - // transaction. However, we might not still be on the right shard (as the - // shard might be inferred from the first query). So we parse the query and - // store the `BEGIN` statement. Upon receiving the next query (and possibly - // determining the shard), we will execute the `BEGIN` statement. - if let Some(ast_vec) = initial_parsed_ast.as_ref() { - if Self::is_begin_statement(ast_vec) { - assert_eq!(ast_vec.len(), 1); - - initialize_xact_info(self, &ast_vec[0]); - - custom_protocol_response_ok_with_state( - &mut self.write, - "BEGIN", - self.xact_info.state(), - ) - .await?; + return Ok(ControlFlow::Continue(())); + } + } - continue; - } - } + if query_router.query_parser_enabled() { + return self + .parse_ast_helper(query_router, initial_parsed_ast, message, client_identifier) + .await; + } + }; + Ok(ControlFlow::Break(None)) + } - if query_router.query_parser_enabled() { - let should_continue; - (should_continue, ast) = self - .parse_ast_helper( - &mut query_router, - &mut initial_parsed_ast, - &message, - &client_identifier, - ) - .await?; - if !should_continue { - continue; - } - } - }; + #[allow(clippy::too_many_arguments)] + async fn handle_transaction_loop<'a, 'b>( + &mut self, + client_identifier: &ClientIdentifier, + query_router: &mut QueryRouter, + pool: &'a ConnectionPool, + all_conns: 
&mut HashMap< + ServerId, + (bb8::PooledConnection<'b, crate::pool::ServerPool>, Address), + >, + mut initial_message: Option, + mut initial_parsed_ast: Option>, + will_prepare: &mut bool, + prepared_statements_enabled: &mut bool, + prepared_statement: &mut Option, + plugin_output: &mut Option, + ) -> Result>, Error> + where + 'a: 'b, + { + let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { + 0 => tokio::time::Duration::MAX, + timeout => tokio::time::Duration::from_millis(timeout), + }; + + // Transaction loop. Multiple queries can be issued by the client here. + // The connection belongs to the client until the transaction is over, + // or until the client disconnects if we are in session mode. + // + // If the client is in session or transaction modes, no more custom protocol + // commands will be accepted. However, in transparent mode, the `SET SHARD` and + // `SET SHARDING KEY` custom protocol commands are still accepted. + loop { + let is_first_message_to_server = all_conns.is_empty(); + + let mut message: BytesMut = match self + .read_message_helper( + idle_client_timeout_duration, + query_router, + all_conns, + initial_message.take(), + &mut initial_parsed_ast, + ) + .await? + { + ControlFlow::Continue(message) => message, + ControlFlow::Break(_) => { + break; + } + }; + + // Safe to unwrap because we know this message has a certain length and has the code + // This reads the first byte without advancing the internal pointer and mutating the bytes + let code = *message.first().unwrap() as char; + let mut ast = match self + .handle_begin_statement( + code, + &message, + client_identifier, + query_router, + &mut initial_parsed_ast, + ) + .await? 
+ { + ControlFlow::Continue(()) => { + continue; + } - assign_client_transaction_state(self, &all_conns); + ControlFlow::Break(ast) => ast, + }; + + self.assign_client_transaction_state(all_conns); - if all_conns.is_empty() || self.is_transparent_mode() { - let current_shard = query_router.shard(); + if all_conns.is_empty() || self.is_transparent_mode() { + let current_shard = query_router.shard(); - // Handle all custom protocol commands, if any. - match query_router.try_execute_command(&message) { - // Normal query, not a custom command. - None => (), + // Handle all custom protocol commands, if any. + match query_router.try_execute_command(&message) { + // Normal query, not a custom command. + None => (), - // SET SHARD TO - Some((Command::SetShard, _)) => { - match query_router.shard() { - None => (), - Some(selected_shard) => { - if selected_shard >= pool.shards() { - // Bad shard number, send error message to client. - query_router.set_shard(current_shard); + // SET SHARD TO + Some((Command::SetShard, _)) => { + match query_router.shard() { + None => (), + Some(selected_shard) => { + if selected_shard >= pool.shards() { + // Bad shard number, send error message to client. 
+ query_router.set_shard(current_shard); - error_response_with_state( + error_response_with_state( &mut self.write, &format!( "shard {} is not configured {}, staying on shard {:?} (shard numbers start at 0)", @@ -1169,146 +1339,141 @@ where self.xact_info.state(), ) .await?; - } else { - custom_protocol_response_ok_with_state( - &mut self.write, - "SET SHARD", - self.xact_info.state(), - ) - .await?; - } + } else { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET SHARD", + self.xact_info.state(), + ) + .await?; } } - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } } - - // SET PRIMARY READS TO - Some((Command::SetPrimaryReads, _)) => { - custom_protocol_response_ok_with_state( - &mut self.write, - "SET PRIMARY READS", - self.xact_info.state(), - ) - .await?; - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; } + } - // SET SHARDING KEY TO - Some((Command::SetShardingKey, _)) => { - custom_protocol_response_ok_with_state( - &mut self.write, - "SET SHARDING KEY", - self.xact_info.state(), - ) - .await?; - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } + // SET PRIMARY READS TO + Some((Command::SetPrimaryReads, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET PRIMARY READS", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; } - - // SET SERVER ROLE TO - Some((Command::SetServerRole, _)) => { - custom_protocol_response_ok_with_state( - &mut self.write, - "SET SERVER ROLE", - self.xact_info.state(), - ) - .await?; - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } + } + + // SET 
SHARDING KEY TO + Some((Command::SetShardingKey, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET SHARDING KEY", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; } + } - // SHOW SERVER ROLE - Some((Command::ShowServerRole, value)) => { - show_response(&mut self.write, "server role", &value).await?; - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } + // SET SERVER ROLE TO + Some((Command::SetServerRole, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET SERVER ROLE", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; } + } - // SHOW SHARD - Some((Command::ShowShard, value)) => { - show_response(&mut self.write, "shard", &value).await?; - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } + // SHOW SERVER ROLE + Some((Command::ShowServerRole, value)) => { + show_response(&mut self.write, "server role", &value).await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; } + } - // SHOW PRIMARY READS - Some((Command::ShowPrimaryReads, value)) => { - show_response(&mut self.write, "primary reads", &value).await?; - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } + // SHOW SHARD + Some((Command::ShowShard, value)) => { + show_response(&mut self.write, "shard", &value).await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; } - }; + } - debug!("Waiting for connection from pool"); - if !self.admin { - self.stats.waiting(); + // SHOW PRIMARY READS + Some((Command::ShowPrimaryReads, value)) => { + show_response(&mut self.write, "primary 
reads", &value).await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } } + }; - let query_router_shard = query_router.shard(); - let server_key = query_router_shard.unwrap_or(0); - let mut conn_opt = all_conns.get_mut(&server_key); - if conn_opt.is_none() { - // Grab a server from the pool. - let connection = match pool - .get(query_router_shard, query_router.role(), &self.stats) - .await - { - Ok(conn) => { - debug!("Got connection from pool"); - conn - } - Err(err) => { - // Client is attempting to get results from the server, - // but we were unable to grab a connection from the pool - // We'll send back an error message and clean the extended - // protocol buffer - self.stats.idle(); - - if message[0] as char == 'S' { - error!("Got Sync message but failed to get a connection from the pool"); - self.buffer.clear(); - } + debug!("Waiting for connection from pool"); + if !self.admin { + self.stats.waiting(); + } - error_response_with_state( - &mut self.write, - format!("could not get connection from the pool - {}", err) - .as_str(), - self.xact_info.state(), - ) - .await?; + let query_router_shard = query_router.shard(); + let server_key = query_router_shard.unwrap_or(0); + let mut conn_opt = all_conns.get_mut(&server_key); + if conn_opt.is_none() { + // Grab a server from the pool. 
+ let connection = match pool + .get(query_router_shard, query_router.role(), &self.stats) + .await + { + Ok(conn) => { + debug!("Got connection from pool"); + conn + } + Err(err) => { + // Client is attempting to get results from the server, + // but we were unable to grab a connection from the pool + // We'll send back an error message and clean the extended + // protocol buffer + self.stats.idle(); + if message[0] as char == 'S' { error!( - "Could not get connection from pool: \ + "Got Sync message but failed to get a connection from the pool" + ); + self.buffer.clear(); + } + + error_response_with_state( + &mut self.write, + format!("could not get connection from the pool - {}", err) + .as_str(), + self.xact_info.state(), + ) + .await?; + + error!( + "Could not get connection from pool: \ {{ \ pool_name: {:?}, \ username: {:?}, \ @@ -1316,477 +1481,434 @@ where role: \"{:?}\", \ error: \"{:?}\" \ }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role(), - err - ); + self.pool_name, + self.username, + query_router.shard(), + query_router.role(), + err + ); - if self.is_in_idle_transaction() { - should_continue_in_outer_loop = true; - break; - } else { - continue; - } - } - }; - - // Before inserting this new connection, if we had only a single connection - // before, then it means that we have started a distributed transaction. - // At this point, we need to acquire the snapshot from the first server and - // use that snapshot for all the other servers. - if all_conns.len() == 1 { - let (first_server_key, first_conn) = - all_conns.iter_mut().next().unwrap(); - let first_server = &mut *first_conn.0; - - if !acquire_gid_and_snapshot(self, *first_server_key, first_server) - .await? 
- { - break; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; } } + }; - all_conns.insert(server_key, connection); - let is_distributed_xact = all_conns.len() > 1; - conn_opt = if is_distributed_xact { - let conn_opt = all_conns.get_mut(&server_key); - let conn = conn_opt.unwrap(); - let server = &mut *conn.0; - let address = &conn.1; + // Before inserting this new connection, if we had only a single connection + // before, then it means that we have started a distributed transaction. + // At this point, we need to do some prep on the first server. + if all_conns.len() == 1 { + let (first_server_key, first_conn) = all_conns.iter_mut().next().unwrap(); + let first_server = &mut *first_conn.0; + + if !self.acquire_gid(*first_server_key, first_server).await? { + break; + } + } + + all_conns.insert(server_key, connection); - debug!( + let is_distributed_xact = all_conns.len() > 1; + conn_opt = if is_distributed_xact { + let conn_opt = all_conns.get_mut(&server_key); + let conn = conn_opt.unwrap(); + let server = &mut *conn.0; + let address = &conn.1; + + debug!( "Sending implicit BEGIN statement to server {} (in transparent mode with distributed transaction)", address ); - if !begin_distributed_xact(self, server_key, server).await? { - break; - } - Some(conn) - } else { - all_conns.get_mut(&server_key) + if !self.begin_distributed_xact(server_key, server).await? { + break; } + Some(conn) + } else { + all_conns.get_mut(&server_key) } - let conn = conn_opt.unwrap(); - let address = &conn.1; - let server = &mut *conn.0; - - // Server is assigned to the client in case the client wants to - // cancel a query later. 
- server.claim(self.process_id, self.secret_key); - self.connected_to_server = true; - - // Update statistics - self.stats.active(); - - self.last_address_id = Some(address.id); - self.last_server_stats = Some(server.stats()); - self.last_server_key = Some(server_key); - - debug!( - "Client {:?} talking to server {:?}", - self.addr, - server.address() - ); - - server.sync_parameters(&self.server_parameters).await?; } - - let is_distributed_xact = all_conns.len() > 1; - let server_key = query_router.shard().unwrap_or(0); - let conn = all_conns.get_mut(&server_key).unwrap(); - let server = &mut *conn.0; + let conn = conn_opt.unwrap(); let address = &conn.1; + let server = &mut *conn.0; - // Only check if we should rewrite prepared statements - // in session mode. In transaction mode, we check at the beginning of - // each transaction. - if !self.is_transaction_mode() { - prepared_statements_enabled = get_prepared_statements(); - } + // Server is assigned to the client in case the client wants to + // cancel a query later. + server.claim(self.process_id, self.secret_key); + self.connected_to_server = true; - debug!("Prepared statement active: {:?}", prepared_statement); + // Update statistics + self.stats.active(); - // We are processing a prepared statement. - if let Some(ref name) = prepared_statement { - debug!("Checking prepared statement is on server"); - // Get the prepared statement the server expects to see. - let statement = match self.prepared_statements.get(name) { - Some(statement) => { - debug!("Prepared statement `{}` found in cache", name); - statement - } - None => { - return Err(Error::ClientError(format!( - "prepared statement `{}` not found", - name - ))) - } - }; + self.last_address_id = Some(address.id); + self.last_server_stats = Some(server.stats()); + self.last_server_key = Some(server_key); - // Since it's already in the buffer, we don't need to prepare it on this server. 
- if will_prepare { - server.will_prepare(&statement.name); - will_prepare = false; - } else { - // The statement is not prepared on the server, so we need to prepare it. - if server.should_prepare(&statement.name) { - match server.prepare(statement).await { - Ok(_) => (), - Err(err) => { - pool.ban( - address, - BanReason::MessageSendFailed, - Some(&self.stats), - ); - return Err(err); - } - } - } - } + debug!( + "Client {:?} talking to server {:?}", + self.addr, + server.address() + ); - // Done processing the prepared statement. - prepared_statement = None; - } + server.sync_parameters(&self.server_parameters).await?; + } - // The message will be forwarded to the server intact. We still would like to - // parse it below to figure out what to do with it. - - trace!("Message: {}", code); - - match code { - // Query - 'Q' => { - if is_distributed_xact { - // if we are in a distributed transaction, we need to parse the query - // to figure out if it's a COMMIT or ABORT statement. - // If query parsing is disabled, we need to parse it here. Otherwise, - // it's already parsed. - if !query_router.query_parser_enabled() { - assert_eq!(ast, None); - let should_continue; - (should_continue, ast) = self - .parse_ast_helper( - &mut query_router, - &mut initial_parsed_ast, - &message, - &client_identifier, - ) - .await?; - if !should_continue { - continue; - } - } + let is_distributed_xact = all_conns.len() > 1; + let server_key = query_router.shard().unwrap_or(0); + let conn = all_conns.get_mut(&server_key).unwrap(); + let server = &mut *conn.0; + let address = &conn.1; + + // Only check if we should rewrite prepared statements + // in session mode. In transaction mode, we check at the beginning of + // each transaction. 
+ if !self.is_transaction_mode() { + *prepared_statements_enabled = get_prepared_statements(); + } - if let Some(ast) = &ast { - if is_distributed_xact && set_commit_or_abort_statement(self, ast) { - break; - } - } - } - debug!("Sending query to server (in Query mode)"); - - // If this is the first message that we're actually sending to the server, - // we need to send the 'BEGIN' statement first if it was issued before this. - // This is the case when the client is in transparent mode and A 'BEGIN' - // statement was issued before the first query. - if is_first_message_to_server { - if let Some(begin_stmt) = self.xact_info.get_begin_statement() { - let begin_stmt = begin_stmt.clone(); - if let Some(err) = - query_server(self, server, &begin_stmt.to_string()).await? - { - error_response_stmt( - &mut self.write, - &err, - self.xact_info.state(), - ) - .await?; - break; - } + debug!("Prepared statement active: {:?}", *prepared_statement); - initialize_xact_params(self, server, &begin_stmt); + // We are processing a prepared statement. + if let Some(ref name) = *prepared_statement { + debug!("Checking prepared statement is on server"); + // Get the prepared statement the server expects to see. + let statement = match self.prepared_statements.get(name) { + Some(statement) => { + debug!("Prepared statement `{}` found in cache", name); + statement + } + None => { + return Err(Error::ClientError(format!( + "prepared statement `{}` not found", + name + ))) + } + }; - assert!(server.in_transaction()); + // Since it's already in the buffer, we don't need to prepare it on this server. + if *will_prepare { + server.will_prepare(&statement.name); + *will_prepare = false; + } else { + // The statement is not prepared on the server, so we need to prepare it. 
+ if server.should_prepare(&statement.name) { + match server.prepare(statement).await { + Ok(_) => (), + Err(err) => { + pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats)); + return Err(err); } } + } + } - self.send_and_receive_loop( - code, - Some(&message), - server, - address, - &pool, - &self.stats.clone(), - ) - .await?; + // Done processing the prepared statement. + *prepared_statement = None; + } + + // The message will be forwarded to the server intact. We still would like to + // parse it below to figure out what to do with it. + + trace!("Message: {}", code); - if !server.in_transaction() { - // Report transaction executed statistics. - self.stats.transaction(); - server - .stats() - .transaction(self.server_parameters.get_application_name()); - - // Release server back to the pool if we are in transaction or transparent modes. - // If we are in session mode, we keep the server until the client disconnects. - if (self.is_transaction_mode() || self.is_transparent_mode()) - && !server.in_copy_mode() + match code { + // Query + 'Q' => { + if is_distributed_xact { + // if we are in a distributed transaction, we need to parse the query + // to figure out if it's a COMMIT or ABORT statement. + // If query parsing is disabled, we need to parse it here. Otherwise, + // it's already parsed. + if !query_router.query_parser_enabled() { + assert_eq!(ast, None); + ast = match self + .parse_ast_helper( + query_router, + &mut initial_parsed_ast, + &message, + client_identifier, + ) + .await? { - self.stats.idle(); + ControlFlow::Continue(()) => { + continue; + } + + ControlFlow::Break(ast) => ast, + } + } + if let Some(ast) = &ast { + if is_distributed_xact && self.set_commit_or_abort_statement(ast) { break; } } } + debug!("Sending query to server"); + + // If this is the first message that we're actually sending to the server, + // we need to send the 'BEGIN' statement first if it was issued before this. 
+ // This is the case when the client is in transparent mode and A 'BEGIN' + // statement was issued before the first query. + if is_first_message_to_server { + if let Some(begin_stmt) = self.xact_info.get_begin_statement() { + let begin_stmt = begin_stmt.clone(); + let res = server.query(&begin_stmt.to_string()).await; + + if self.post_query_processing(server, res).await?.is_none() { + break; + } - // Terminate - 'X' => { - server.checkin_cleanup().await?; - self.stats.disconnect(); - self.release(); + self.initialize_xact_params(server, &begin_stmt); - if prepared_statements_enabled { - server.maintain_cache().await?; + assert!(server.in_transaction()); } - - return Ok(()); } - // Parse - // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. - 'P' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_parse(message)?; - will_prepare = true; - } + self.send_and_receive_loop( + code, + Some(&message), + server, + address, + pool, + &self.stats.clone(), + ) + .await?; - if query_router.query_parser_enabled() { - if let Ok(ast) = query_router.parse(&message) { - if let Ok(output) = query_router.execute_plugins(&ast).await { - plugin_output = Some(output); - } - } - } + if !server.in_transaction() { + // Report transaction executed statistics. + self.stats.transaction(); + server + .stats() + .transaction(self.server_parameters.get_application_name()); + + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. + if (self.is_transaction_mode() || self.is_transparent_mode()) + && !server.in_copy_mode() + { + self.stats.idle(); - self.buffer.put(&message[..]); + break; + } } + } - // Bind - // The placeholder's replacements are here, e.g. 
'user@email.com' and 'true' - 'B' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_bind(message).await?; - } + // Terminate + 'X' => { + server.checkin_cleanup().await?; + self.stats.disconnect(); + self.release(); - self.buffer.put(&message[..]); + if *prepared_statements_enabled { + server.maintain_cache().await?; } - // Describe - // Command a client can issue to describe a previously prepared named statement. - 'D' => { - if prepared_statements_enabled { - let name; - (name, message) = self.rewrite_describe(message).await?; - - if let Some(name) = name { - prepared_statement = Some(name); - } - } + return Ok(Some(ControlFlow::Break(()))); + } - self.buffer.put(&message[..]); + // Parse + // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. + 'P' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_parse(message)?; + *will_prepare = true; } - // Close the prepared statement. - 'C' => { - if prepared_statements_enabled { - let close: Close = (&message).try_into()?; - - if close.is_prepared_statement() && !close.anonymous() { - if let Some(parse) = self.prepared_statements.get(&close.name) { - server.will_close(&parse.generated_name); - } else { - // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. - }; + if query_router.query_parser_enabled() { + if let Ok(ast) = query_router.parse(&message) { + if let Ok(output) = query_router.execute_plugins(&ast).await { + *plugin_output = Some(output); } } - - self.buffer.put(&message[..]); } - // Execute - // Execute a prepared statement prepared in `P` and bound in `B`. - 'E' => { - self.buffer.put(&message[..]); - } + self.buffer.put(&message[..]); + } - // Sync - // Frontend (client) is asking for the query result now. - 'S' => { - debug!("Sending query to server (in Sync mode)"); + // Bind + // The placeholder's replacements are here, e.g. 
'user@email.com' and 'true' + 'B' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_bind(message).await?; + } - match plugin_output { - Some(PluginOutput::Deny(error)) => { - error_response_with_state( - &mut self.write, - &error, - self.xact_info.state(), - ) - .await?; - plugin_output = None; - self.buffer.clear(); - continue; - } + self.buffer.put(&message[..]); + } - Some(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - plugin_output = None; - self.buffer.clear(); - continue; - } + // Describe + // Command a client can issue to describe a previously prepared named statement. + 'D' => { + if *prepared_statements_enabled { + let name; + (name, message) = self.rewrite_describe(message).await?; - _ => (), - }; + if let Some(name) = name { + *prepared_statement = Some(name); + } + } - self.buffer.put(&message[..]); + self.buffer.put(&message[..]); + } - let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; + // Close the prepared statement. + 'C' => { + if *prepared_statements_enabled { + let close: Close = (&message).try_into()?; - // Almost certainly true - if first_message_code == 'P' && !prepared_statements_enabled { - // Message layout - // P followed by 32 int followed by null-terminated statement name - // So message code should be in offset 0 of the buffer, first character - // in prepared statement name would be index 5 - let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); - if first_char_in_name != 0 { - // This is a named prepared statement - // Server connection state will need to be cleared at checkin - server.mark_dirty(); - } + if close.is_prepared_statement() && !close.anonymous() { + if let Some(parse) = self.prepared_statements.get(&close.name) { + server.will_close(&parse.generated_name); + } else { + // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. 
+ }; } + } - self.send_and_receive_loop( - code, - None, - server, - address, - &pool, - &self.stats.clone(), - ) - .await?; + self.buffer.put(&message[..]); + } - self.buffer.clear(); + // Execute + // Execute a prepared statement prepared in `P` and bound in `B`. + 'E' => { + self.buffer.put(&message[..]); + } - if !server.in_transaction() { - self.stats.transaction(); - server - .stats() - .transaction(self.server_parameters.get_application_name()); + // Sync + // Frontend (client) is asking for the query result now. + 'S' => { + debug!("Sending query to server"); - // Release server back to the pool if we are in transaction or transparent modes. - // If we are in session mode, we keep the server until the client disconnects. - if (self.is_transaction_mode() || self.is_transparent_mode()) - && !server.in_copy_mode() - { - break; - } + match plugin_output { + Some(PluginOutput::Deny(error)) => { + error_response_with_state( + &mut self.write, + error, + self.xact_info.state(), + ) + .await?; + *plugin_output = None; + self.buffer.clear(); + continue; } - } - - // CopyData - 'd' => { - self.buffer.put(&message[..]); - // Want to limit buffer size - if self.buffer.len() > 8196 { - // Forward the data to the server, - self.send_server_message(server, &self.buffer, address, &pool) - .await?; + Some(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + *plugin_output = None; self.buffer.clear(); + continue; } - } - // CopyDone or CopyFail - // Copy is done, successfully or not. 
- 'c' | 'f' => { - // We may already have some copy data in the buffer, add this message to buffer - self.buffer.put(&message[..]); + _ => (), + }; - self.send_server_message(server, &self.buffer, address, &pool) - .await?; + self.buffer.put(&message[..]); - // Clear the buffer - self.buffer.clear(); + let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; + + // Almost certainly true + if first_message_code == 'P' && !*prepared_statements_enabled { + // Message layout + // P followed by 32 int followed by null-terminated statement name + // So message code should be in offset 0 of the buffer, first character + // in prepared statement name would be index 5 + let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); + if first_char_in_name != 0 { + // This is a named prepared statement + // Server connection state will need to be cleared at checkin + server.mark_dirty(); + } + } - let response = self - .receive_server_message(server, address, &pool, &self.stats.clone()) - .await?; + self.send_and_receive_loop( + code, + None, + server, + address, + pool, + &self.stats.clone(), + ) + .await?; - match write_all_flush(&mut self.write, &response).await { - Ok(_) => (), - Err(err) => { - server.mark_bad(); - return Err(err); - } - }; + self.buffer.clear(); - if !server.in_transaction() { - self.stats.transaction(); - server - .stats() - .transaction(self.server_parameters.get_application_name()); + if !server.in_transaction() { + self.stats.transaction(); + server + .stats() + .transaction(self.server_parameters.get_application_name()); - // Release server back to the pool if we are in transaction or transparent modes. - // If we are in session mode, we keep the server until the client disconnects. - if self.is_transaction_mode() || self.is_transparent_mode() { - break; - } + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. 
+ if (self.is_transaction_mode() || self.is_transparent_mode()) + && !server.in_copy_mode() + { + break; } } + } + + // CopyData + 'd' => { + self.buffer.put(&message[..]); - // Some unexpected message. We either did not implement the protocol correctly - // or this is not a Postgres client we're talking to. - _ => { - error!("Unexpected code: {}", code); + // Want to limit buffer size + if self.buffer.len() > 8196 { + // Forward the data to the server, + self.send_server_message(server, &self.buffer, address, pool) + .await?; + self.buffer.clear(); } } - } - - if should_continue_in_outer_loop { - continue; - } - distributed_commit_or_abort(self, &mut all_conns).await?; + // CopyDone or CopyFail + // Copy is done, successfully or not. + 'c' | 'f' => { + // We may already have some copy data in the buffer, add this message to buffer + self.buffer.put(&message[..]); - // Reset transaction state for safety reasons. Even if this state will be reset before - // the next transaction, this dirty state could be seen in-between here and there. - reset_client_xact(self); + self.send_server_message(server, &self.buffer, address, pool) + .await?; - debug!("Releasing servers back into the pool"); + // Clear the buffer + self.buffer.clear(); - for (_, mut conn) in all_conns { - let server = &mut *conn.0; - let address = &conn.1; + let response = self + .receive_server_message(server, address, pool, &self.stats.clone()) + .await?; - // The server is no longer bound to us, we can't cancel it's queries anymore. 
- debug!("Releasing server back into the pool: {}", address); + match write_all_flush(&mut self.write, &response).await { + Ok(_) => (), + Err(err) => { + server.mark_bad(); + return Err(err); + } + }; - server.checkin_cleanup().await?; + if !server.in_transaction() { + self.stats.transaction(); + server + .stats() + .transaction(self.server_parameters.get_application_name()); - if prepared_statements_enabled { - server.maintain_cache().await?; + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. + if self.is_transaction_mode() || self.is_transparent_mode() { + break; + } + } } - server.stats().idle(); + // Some unexpected message. We either did not implement the protocol correctly + // or this is not a Postgres client we're talking to. + _ => { + error!("Unexpected code: {}", code); + } } - - self.connected_to_server = false; - - self.release(); - self.stats.idle(); } + + Ok(None) } async fn parse_ast_helper( @@ -1795,7 +1917,7 @@ where initial_parsed_ast: &mut Option>, message: &BytesMut, client_identifier: &ClientIdentifier, - ) -> Result<(bool, Option>), Error> { + ) -> Result>>, Error> { Ok({ // We don't want to parse again if we already parsed it as the initial message let ast = parse_ast(initial_parsed_ast, query_router, message, client_identifier); @@ -1807,18 +1929,18 @@ where Ok(PluginOutput::Deny(error)) => { error_response_with_state(&mut self.write, &error, self.xact_info.state()) .await?; - (false, None) + ControlFlow::Continue(()) } Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - (false, None) + write_all(&mut self.write, &result).await?; + ControlFlow::Continue(()) } - _ => (true, ast), + _ => ControlFlow::Break(ast), } } else { - (true, None) + ControlFlow::Break(None) } }) } @@ -1832,7 +1954,7 @@ where error_response( &mut self.write, &format!( - "No pool configured for database: {}, user: {} (in 
get_pool)", + "No pool configured for database: {}, user: {}", self.pool_name, self.username ), ) @@ -1998,7 +2120,7 @@ where // Report query executed statistics. client_stats.query(); server.stats().query( - Instant::now().duration_since(query_start).as_millis() as u64, + query_start.elapsed().as_millis() as u64, self.server_parameters.get_application_name(), ); @@ -2076,12 +2198,8 @@ where self.xact_info.is_idle() } - fn is_begin_statement(ast_vec: &[sqlparser::ast::Statement]) -> bool { - ast_vec.len() == 1 - && matches!( - ast_vec[0], - sqlparser::ast::Statement::StartTransaction { .. } - ) + fn is_begin_statement(ast_vec: &[Statement]) -> bool { + ast_vec.len() == 1 && matches!(ast_vec[0], Statement::StartTransaction { .. }) } } diff --git a/src/client_xact.rs b/src/client_xact.rs index 8a847a70..86a4b2ef 100644 --- a/src/client_xact.rs +++ b/src/client_xact.rs @@ -2,9 +2,9 @@ use crate::client::Client; use crate::errors::Error; use crate::query_messages::{ErrorInfo, ErrorResponse, Message}; use bytes::BytesMut; + use core::panic; use futures::future::join_all; -use itertools::Either; use log::{debug, warn}; use sqlparser::ast::{Statement, TransactionAccessMode, TransactionMode}; use std::collections::HashMap; @@ -17,510 +17,483 @@ use crate::server_xact::*; pub type ServerId = usize; -/// This function starts a distributed transaction by sending a BEGIN statement to the first server. -/// It is called on the first server, as soon as client wants to interact with another server, -/// which hints that the client wants to start a distributed transaction. -pub async fn begin_distributed_xact( - clnt: &mut Client, - server_key: ServerId, - server: &mut Server, -) -> Result -where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, -{ - let begin_stmt = clnt.xact_info.get_begin_statement(); - assert!(begin_stmt.is_some()); - if let Some(err) = query_server(clnt, server, &begin_stmt.unwrap().to_string()).await? 
{ - error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; - return Ok(false); +/// The metadata of a server transaction. +#[derive(Default, Debug, Clone)] +pub struct ClientTxnMetaData { + begin_statement: Option, + commit_statement: Option, + abort_statement: Option, + + pub params: CommonTxnParams, +} + +impl ClientTxnMetaData { + pub fn set_state(&mut self, state: TransactionState) { + set_state_helper(&mut self.params.state, state); } - if clnt.xact_info.params.is_repeatable_read_or_higher() { - // If we are in a repeatable read or serializable transaction, we need to use the - // snapshot we acquired from the first server. - assert!(clnt.xact_info.get_snapshot().is_some()); - let snapshot = clnt.xact_info.get_snapshot().unwrap(); + pub fn state(&self) -> TransactionState { + self.params.state + } - debug!( - "Assigning snapshot ('{}') to server {}", - snapshot, - server.address(), - ); + pub fn is_idle(&self) -> bool { + self.params.state == TransactionState::Idle + } - let snapshot_res = assign_xact_snapshot(server, &snapshot).await?; - set_post_query_state(clnt, server); + pub fn set_xact_gid(&mut self, xact_gid: Option) { + self.params.xact_gid = xact_gid; + } - if let Some(err) = snapshot_res { - error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; - return Ok(false); - }; + pub fn get_xact_gid(&self) -> Option { + self.params.xact_gid.clone() + } + + pub fn set_begin_statement(&mut self, begin_statement: Option) { + self.begin_statement = begin_statement; } - // If we are in a distributed transaction, we need to assign a GID to the transaction. 
- assert!(clnt.xact_info.get_xact_gid().is_some()); - let gid = clnt.xact_info.get_xact_gid().unwrap(); + pub fn get_begin_statement(&self) -> Option<&Statement> { + self.begin_statement.as_ref() + } + + pub fn set_commit_statement(&mut self, commit_statement: Option) { + self.commit_statement = commit_statement; + } - debug!("Assigning GID ('{}') to server {}", gid, server.address(),); + pub fn get_commit_statement(&self) -> Option<&Statement> { + self.commit_statement.as_ref() + } - let gid_res = assign_xact_gid(server, &gen_server_specific_gid(server_key, &gid)).await?; - set_post_query_state(clnt, server); - if let Some(err) = gid_res { - error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; - return Ok(false); + pub fn set_abort_statement(&mut self, abort_statement: Option) { + self.abort_statement = abort_statement; } - Ok(true) + pub fn get_abort_statement(&self) -> Option<&Statement> { + self.abort_statement.as_ref() + } } -/// This functions generates a GID for the current transaction and sends it to the server. -/// Also, if the transaction is repeatable read or higher, it acquires a snapshot from the server. -pub async fn acquire_gid_and_snapshot( - clnt: &mut Client, - server_key: ServerId, - server: &mut Server, -) -> Result +impl Client where S: tokio::io::AsyncRead + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin, { - assert!(clnt.xact_info.get_xact_gid().is_none()); - let gid = generate_xact_gid(clnt); - - debug!( - "Acquiring GID ('{}') and snapshot from server {}", - gid, - server.address(), - ); - - // If we are in a distributed transaction, we need to assign a GID to the transaction. 
- let gid_res = assign_xact_gid(server, &gen_server_specific_gid(server_key, &gid)).await?; - set_post_query_state(clnt, server); - if let Some(err) = gid_res { - error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; - return Ok(false); - } - clnt.xact_info.set_xact_gid(Some(gid)); - - if clnt.xact_info.params.is_repeatable_read_or_higher() { - // If we are in a repeatable read or serializable transaction, we need to acquire a - // snapshot from the server. - let snapshot_res = acquire_xact_snapshot(server).await?; - set_post_query_state(clnt, server); - - match snapshot_res { - Either::Left(snapshot) => { - debug!( - "Got first server snapshot: {} (on {})", - snapshot, - server.address() - ); - clnt.xact_info.set_snapshot(Some(snapshot)); - } - Either::Right(err) => { - error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; - return Ok(false); - } + /// This function starts a distributed transaction by sending a BEGIN statement to the first server. + /// It is called on the first server, as soon as client wants to interact with another server, + /// which hints that the client wants to start a distributed transaction. + pub async fn begin_distributed_xact( + &mut self, + server_key: ServerId, + server: &mut Server, + ) -> Result { + let begin_stmt = self.xact_info.get_begin_statement(); + assert!(begin_stmt.is_some()); + let res = server.query(&begin_stmt.unwrap().to_string()).await; + if self.post_query_processing(server, res).await?.is_none() { + return Ok(false); + } + + // If we are in a distributed transaction, we need to assign a GID to the transaction. 
+ assert!(self.xact_info.get_xact_gid().is_some()); + let gid = self.xact_info.get_xact_gid().unwrap(); + + debug!("Assigning GID ('{}') to server {}", gid, server.address(),); + + let gid_res = server + .assign_xact_gid(&Self::gen_server_specific_gid(server_key, &gid)) + .await; + if self.post_query_processing(server, gid_res).await?.is_none() { + return Ok(false); } + + Ok(true) } - Ok(true) -} -/// Generates a random GID (i.e., Global transaction ID) for a transaction. -fn generate_xact_gid(clnt: &Client) -> String { - format!("txn_{}_{}", clnt.addr, Uuid::new_v4()) -} + /// This functions generates a GID for the current transaction and sends it to the server. + pub async fn acquire_gid( + &mut self, + server_key: ServerId, + server: &mut Server, + ) -> Result { + assert!(self.xact_info.get_xact_gid().is_none()); + let gid = self.generate_xact_gid(); + + debug!("Acquiring GID ('{}') from server {}", gid, server.address(),); + + // If we are in a distributed transaction, we need to assign a GID to the transaction. + let gid_res = server + .assign_xact_gid(&Self::gen_server_specific_gid(server_key, &gid)) + .await; + if self.post_query_processing(server, gid_res).await?.is_none() { + return Ok(false); + } + self.xact_info.set_xact_gid(Some(gid)); + Ok(true) + } -/// Generates a server-specific GID for a transaction. We need this, because it's possible that -/// multiple servers might actually be the same server (which commonly happens in testing). -fn gen_server_specific_gid(server_key: ServerId, gid: &str) -> String { - format!("{}_{}", server_key, gid) -} + /// Generates a random GID (i.e., Global transaction ID) for a transaction. + fn generate_xact_gid(&self) -> String { + format!("txn_{}_{}", self.addr, Uuid::new_v4()) + } -/// Assigns the transaction state based on the state of all servers. 
-pub fn assign_client_transaction_state( - clnt: &mut Client, - all_conns: &HashMap, Address)>, -) { - clnt.xact_info.set_state(if all_conns.is_empty() { - // if there's no server, we're in idle mode. - TransactionState::Idle - } else if is_any_server_in_failed_xact(all_conns) { - // if any server is in failed transaction, we're in failed transaction. - TransactionState::InFailedTransaction - } else { - // if we have at least one server and it is in a transaction, we're in a transaction. - TransactionState::InTransaction - }); -} + /// Generates a server-specific GID for a transaction. We need this, because it's possible that + /// multiple servers might actually be the same server (which commonly happens in testing). + fn gen_server_specific_gid(server_key: ServerId, gid: &str) -> String { + format!("{}_{}", server_key, gid) + } -/// Returns true if any server is in a failed transaction. -fn is_any_server_in_failed_xact( - all_conns: &HashMap, Address)>, -) -> bool { - all_conns - .iter() - .any(|(_, conn)| in_failed_transaction(&conn.0)) -} + /// Assigns the transaction state based on the state of all servers. + pub fn assign_client_transaction_state( + &mut self, + all_conns: &HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) { + self.xact_info.set_state(if all_conns.is_empty() { + // if there's no server, we're in idle mode. + TransactionState::Idle + } else if Self::is_any_server_in_failed_xact(all_conns) { + // if any server is in failed transaction, we're in failed transaction. + TransactionState::InFailedTransaction + } else { + // if we have at least one server and it is in a transaction, we're in a transaction. + TransactionState::InTransaction + }); + } -/// This function initializes the transaction parameters based on the first BEGIN statement. -pub fn initialize_xact_info(clnt: &mut Client, begin_stmt: &Statement) { - if let Statement::StartTransaction { .. 
} = begin_stmt { - // This is the first BEGIN statement. We need to register it for later executions. - clnt.xact_info.set_begin_statement(Some(begin_stmt.clone())); - - clnt.xact_info.set_state(TransactionState::InTransaction); - } else { - // If we were not in a transaction and the first statement is - // not a BEGIN, then it's an irrecoverable error. - panic!("The first statement in a transaction is not a BEGIN statement."); + /// Returns true if any server is in a failed transaction. + fn is_any_server_in_failed_xact( + all_conns: &HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> bool { + all_conns + .iter() + .any(|(_, conn)| conn.0.in_failed_transaction()) } -} -/// This function initializes the transaction parameters based on the server's default. -pub fn initialize_xact_params( - clnt: &mut Client, - server: &mut Server, - begin_stmt: &Statement, -) { - if let Statement::StartTransaction { modes } = begin_stmt { - // Initialize transaction parameters using the server's default. - clnt.xact_info.params = server_default_transaction_parameters(server); - for mode in modes { - match mode { - TransactionMode::AccessMode(access_mode) => { - clnt.xact_info.params.set_read_only(match access_mode { - TransactionAccessMode::ReadOnly => true, - TransactionAccessMode::ReadWrite => false, - }); - } - TransactionMode::IsolationLevel(isolation_level) => { - clnt.xact_info.params.set_isolation_level(*isolation_level); + /// This function initializes the transaction parameters based on the server's default. + pub fn initialize_xact_params(&mut self, server: &mut Server, begin_stmt: &Statement) { + if let Statement::StartTransaction { modes } = begin_stmt { + // Initialize transaction parameters using the server's default. 
+ self.xact_info.params = server.server_default_transaction_parameters(); + for mode in modes { + match mode { + TransactionMode::AccessMode(access_mode) => { + self.xact_info.params.set_read_only(match access_mode { + TransactionAccessMode::ReadOnly => true, + TransactionAccessMode::ReadWrite => false, + }); + } + TransactionMode::IsolationLevel(isolation_level) => { + self.xact_info.params.set_isolation_level(*isolation_level); + } } } + debug!( + "Transaction paramaters after the first BEGIN statement: {:?}", + self.xact_info.params + ); + + // Set the transaction parameters on the first server. + server.transaction_metadata_mut().params = self.xact_info.params.clone(); + } else { + // If it's not a BEGIN, then it's an irrecoverable error. + panic!("The statement is not a BEGIN statement."); } - debug!( - "Transaction paramaters after the first BEGIN statement: {:?}", - clnt.xact_info.params - ); - - // Set the transaction parameters on the first server. - server.transaction_metadata_mut().params = clnt.xact_info.params.clone(); - } else { - // If it's not a BEGIN, then it's an irrecoverable error. - panic!("The statement is not a BEGIN statement."); } -} -/// This function performs a distribted abort/commit if necessary, and also resets the transaction -/// state. This is suppoed to be called before exiting the transaction loop. At that point, if -/// either an abort or commit statement is set, we need to perform a distributed abort/commit. This -/// is based on the logic that an abort or commit statement is only set if we are in a distributed -/// transaction and we observe a commit or abort statement sent to the server. That is where we exit -/// the transaction loop and expect this function to takeover and abort/commit the transaction. 
-pub async fn distributed_commit_or_abort( - clnt: &mut Client, - all_conns: &mut HashMap< - ServerId, - (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), - >, -) -> Result<(), Error> -where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, -{ - let dist_commit = clnt.xact_info.get_commit_statement(); - let dist_abort = clnt.xact_info.get_abort_statement(); - if dist_commit.is_some() || dist_abort.is_some() { - // if either a commit or abort statement is set, we should be in a distributed transaction. - assert!(!all_conns.is_empty()); - - let is_chained = should_be_chained(dist_commit, dist_abort); - let dist_commit = dist_commit.map(|stmt| stmt.to_string()); - let mut dist_abort = dist_abort.map(|stmt| stmt.to_string()); - - // Report transaction executed statistics. - clnt.stats.transaction(); - - let mut is_distributed_commit_failed = false; - // We are in distributed transaction mode, and need to commit or abort on all servers. - if let Some(commit_stmt) = dist_commit { - // If two-phase commit was successful, we can send the COMMIT message to the client. - // Otherwise, we need to ROLLBACK on all servers. - if let Some(err) = distributed_commit(clnt, all_conns).await? { - error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; - - // Currently, if a distributed commit fails, we send a ROLLBACK to all servers. - // However, this is different from how Postgres handles it. Postgres sends an - // error response to the client, and then does not accept any more queries from - // the client until the client explicitly sends a ROLLBACK. - dist_abort = Some("ROLLBACK".to_string()); - is_distributed_commit_failed = true; - } else { - custom_protocol_response_ok_with_state( - &mut clnt.write, - &commit_stmt, - TransactionState::Idle, - ) - .await?; + /// This function performs a distribted abort/commit if necessary, and also resets the transaction + /// state. 
This is suppoed to be called before exiting the transaction loop. At that point, if + /// either an abort or commit statement is set, we need to perform a distributed abort/commit. This + /// is based on the logic that an abort or commit statement is only set if we are in a distributed + /// transaction and we observe a commit or abort statement sent to the server. That is where we exit + /// the transaction loop and expect this function to takeover and abort/commit the transaction. + pub async fn distributed_commit_or_abort( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> Result<(), Error> { + let dist_commit = self.xact_info.get_commit_statement(); + let dist_abort = self.xact_info.get_abort_statement(); + if dist_commit.is_some() || dist_abort.is_some() { + // if either a commit or abort statement is set, we should be in a distributed transaction. + assert!(!all_conns.is_empty()); + + let is_chained = Self::should_be_chained(dist_commit, dist_abort); + let dist_commit = dist_commit.map(|stmt| stmt.to_string()); + let mut dist_abort = dist_abort.map(|stmt| stmt.to_string()); + + // Report transaction executed statistics. + self.stats.transaction(); + + let mut is_distributed_commit_failed = false; + // We are in distributed transaction mode, and need to commit or abort on all servers. + if let Some(commit_stmt) = dist_commit { + // If two-phase commit was successful, we can send the COMMIT message to the client. + // Otherwise, we need to ROLLBACK on all servers. + let dist_commit_res = self.distributed_commit(all_conns).await; + if self + .communicate_err_response(dist_commit_res) + .await? + .is_none() + { + // Currently, if a distributed commit fails, we send a ROLLBACK to all servers. + // However, this is different from how Postgres handles it. 
Postgres sends an + // error response to the client, and then does not accept any more queries from + // the client until the client explicitly sends a ROLLBACK. + dist_abort = Some("ROLLBACK".to_string()); + is_distributed_commit_failed = true; + } else { + custom_protocol_response_ok_with_state( + &mut self.write, + &commit_stmt, + TransactionState::Idle, + ) + .await?; + } } - } - if let Some(abort_stmt) = dist_abort { - let distributed_abort_res = distributed_abort(clnt, all_conns, &abort_stmt).await?; - if is_distributed_commit_failed { - // Nothing to do, as the error reponse is already sent before. - } else if let Some(err) = distributed_abort_res { - error_response_stmt(&mut clnt.write, &err, clnt.xact_info.state()).await?; - } else { - custom_protocol_response_ok_with_state( - &mut clnt.write, - &abort_stmt, - TransactionState::Idle, - ) - .await?; + if let Some(abort_stmt) = dist_abort { + let distributed_abort_res = self.distributed_abort(all_conns, &abort_stmt).await; + if is_distributed_commit_failed { + // Nothing to do, as the error reponse is already sent before. + } else if self + .communicate_err_response(distributed_abort_res) + .await? + .is_some() + { + custom_protocol_response_ok_with_state( + &mut self.write, + &abort_stmt, + TransactionState::Idle, + ) + .await?; + } } - } - let is_all_servers_in_non_copy_mode = - all_conns.iter().all(|(_, conn)| !conn.0.in_copy_mode()); + let is_all_servers_in_non_copy_mode = + all_conns.iter().all(|(_, conn)| !conn.0.in_copy_mode()); - // Release server back to the pool if we are in transaction or transparent modes. - // If we are in session mode, we keep the server until the client disconnects. - if (clnt.is_transaction_mode() || clnt.is_transparent_mode()) - && is_all_servers_in_non_copy_mode - { - clnt.stats.idle(); - } + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. 
+ if (self.is_transaction_mode() || self.is_transparent_mode()) + && is_all_servers_in_non_copy_mode + { + self.stats.idle(); + } - if is_chained { - let last_conn = all_conns - .get_mut(clnt.last_server_key.as_ref().unwrap()) - .unwrap(); - let last_server = &mut *last_conn.0; - - // TODO(MD): chained transaction should be implemented. - // Here, we need to start a local transaction on the last server. However, here is - // too late to start a transaction, as we are far from the transaction loop. We need to - // rearrange the code (or add more complicated control flow) to make it possible. - warn!( - "Chained transaction is not implemented yet. \ + if is_chained { + let last_conn = all_conns + .get_mut(self.last_server_key.as_ref().unwrap()) + .unwrap(); + let last_server = &mut *last_conn.0; + + // TODO(MD): chained transaction should be implemented. + // Here, we need to start a local transaction on the last server. However, here is + // too late to start a transaction, as we are far from the transaction loop. We need to + // rearrange the code (or add more complicated control flow) to make it possible. + warn!( + "Chained transaction is not implemented yet. \ The last server {} will NOT be in transaction.", - last_server.address() - ); + last_server.address() + ); + } } + Ok(()) } - Ok(()) -} -pub fn reset_client_xact(clnt: &mut Client) { - // Reset transaction state for safety reasons. - clnt.xact_info = Default::default(); -} + pub fn reset_client_xact(&mut self) { + // Reset transaction state for safety reasons. 
+ self.xact_info = Default::default(); + } -async fn distributed_commit( - clnt: &mut Client, - all_conns: &mut HashMap< - ServerId, - (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), - >, -) -> Result, Error> -where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, -{ - debug!("Committing distributed transaction."); - if is_any_server_in_failed_xact(all_conns) { - #[cfg(debug_assertions)] - all_conns.iter().for_each(|(server_key, conn)| { - let server = &*conn.0; - if in_failed_transaction(server) { - debug!( + async fn distributed_commit( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> Result<(), Error> { + debug!("Committing distributed transaction."); + if Self::is_any_server_in_failed_xact(all_conns) { + #[cfg(debug_assertions)] + all_conns.iter().for_each(|(server_key, conn)| { + let server = &*conn.0; + if server.in_failed_transaction() { + debug!( "Server {} (with server_key: {:?}) is in failed transaction. 
Skipping commit.", server.address(), server_key, ); - } - }); + } + }); - let err = ErrorInfo::new_brief( - "Error".to_string(), - "25P02".to_string(), - "Cannot commit a transaction that is in failed state.".to_string(), - ); + let err = ErrorInfo::new_brief( + "Error".to_string(), + "25P02".to_string(), + "Cannot commit a transaction that is in failed state.".to_string(), + ); - return Ok(Some(ErrorResponse::from(err))); - } - let res = distributed_prepare(clnt, all_conns).await?; - if res.is_right() { - return Ok(res.right()); - } + return Err(Error::ErrorResponse(ErrorResponse::from(err))); + } + self.distributed_prepare(all_conns).await?; - let commit_prepared_results = join_all(all_conns.iter_mut().map(|(_, conn)| { - let server = &mut *conn.0; - local_server_commit_prepared(server) - })) - .await; + let commit_prepared_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + server.local_server_commit_prepared() + })) + .await; - all_conns.iter_mut().for_each(|(_, conn)| { - let server = &mut *conn.0; - set_post_query_state(clnt, server); - }); + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + }); - for commit_prepared_res in commit_prepared_results { - if let Some(err) = commit_prepared_res? { + for commit_prepared_res in commit_prepared_results { // For now, we just return the first error we encounter. - return Ok(Some(err)); + commit_prepared_res?; } + + Ok(()) } - Ok(None) -} + /// After each interaction with the server, we need to set the transaction state based on the + /// server's state. + fn set_post_query_state(&mut self, server: &mut Server) { + self.xact_info + .set_state(server.transaction_metadata().state()); + } -/// After each interaction with the server, we need to set the transaction state based on the -/// server's state. 
-fn set_post_query_state(clnt: &mut Client, server: &mut Server) { - clnt.xact_info - .set_state(server.transaction_metadata().state()); -} + async fn distributed_abort( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + abort_stmt: &str, + ) -> Result<(), Error> { + debug!("Aborting distributed transaction"); + let abort_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + server.query(abort_stmt) + })) + .await; + + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + server + .stats() + .transaction(self.server_parameters.get_application_name()); + }); -async fn distributed_abort( - clnt: &mut Client, - all_conns: &mut HashMap< - ServerId, - (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), - >, - abort_stmt: &str, -) -> Result, Error> -where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, -{ - debug!("Aborting distributed transaction"); - let abort_results = join_all(all_conns.iter_mut().map(|(_, conn)| { - let server = &mut *conn.0; - server.query(abort_stmt) - })) - .await; - - all_conns.iter_mut().for_each(|(_, conn)| { - let server = &mut *conn.0; - set_post_query_state(clnt, server); - server - .stats() - .transaction(clnt.server_parameters.get_application_name()); - }); - - for abort_res in abort_results { - if let Some(err) = abort_res? { + for abort_res in abort_results { // For now, we just return the first error we encounter. 
- return Ok(Some(err)); + abort_res?; } + Ok(()) } - Ok(None) -} -#[allow(clippy::type_complexity)] -async fn distributed_prepare( - clnt: &mut Client, - all_conns: &mut HashMap< - ServerId, - (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), - >, -) -> Result, Error> -where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, -{ - // Apply 'PREPARE TRANSACTION' on all involved servers. - let prepare_results = join_all(all_conns.iter_mut().map(|(_, conn)| { - let server = &mut *conn.0; - set_post_query_state(clnt, server); - local_server_prepare_transaction(server) - })) - .await; - - // Update the client state based on the server state. - all_conns.iter_mut().for_each(|(_, conn)| { - let server = &mut *conn.0; - set_post_query_state(clnt, server); - }); - - // If there was any error, we need to abort the transaction. - for prepare_res in prepare_results { - if let Some(err) = prepare_res? { - // For now, we just return the first error we encounter. - return Ok(Either::Right(err)); + #[allow(clippy::type_complexity)] + async fn distributed_prepare( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> Result<(), Error> { + // Apply 'PREPARE TRANSACTION' on all involved servers. + let prepare_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + server.local_server_prepare_transaction() + })) + .await; + + // Update the client state based on the server state. + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + }); + + // If there was any error, we need to abort the transaction. + for prepare_res in prepare_results { + prepare_res?; } + Ok(()) } - Ok(Either::Left(())) -} -/// This function is called when the client sends a query to the server without requiring an answer. 
-pub async fn query_server( - clnt: &mut Client, - server: &mut Server, - stmt: &str, -) -> Result, Error> -where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, -{ - let qres = server.query(stmt).await?; - set_post_query_state(clnt, server); - Ok(qres) -} + /// Returns true if the statement is a commit or abort statement. Also, it sets the commit or abort + /// statement on the client. + pub fn set_commit_or_abort_statement(&mut self, ast: &Vec) -> bool { + if Self::is_commit_statement(ast) { + self.xact_info.set_commit_statement(Some(ast[0].clone())); + true + } else if Self::is_abort_statement(ast) { + self.xact_info.set_abort_statement(Some(ast[0].clone())); + true + } else { + false + } + } -/// Returns true if the statement is a commit or abort statement. Also, it sets the commit or abort -/// statement on the client. -pub fn set_commit_or_abort_statement(clnt: &mut Client, ast: &Vec) -> bool { - if is_commit_statement(ast) { - clnt.xact_info.set_commit_statement(Some(ast[0].clone())); - true - } else if is_abort_statement(ast) { - clnt.xact_info.set_abort_statement(Some(ast[0].clone())); - true - } else { + /// Returns true if the statement is a commit statement. + fn is_commit_statement(ast: &Vec) -> bool { + for statement in ast { + if let Statement::Commit { .. } = *statement { + assert_eq!(ast.len(), 1); + return true; + } + } false } -} -/// Returns true if the statement is a commit statement. -fn is_commit_statement(ast: &Vec) -> bool { - for statement in ast { - if let Statement::Commit { .. } = *statement { - assert_eq!(ast.len(), 1); - return true; + /// Returns true if the statement is an abort statement. + fn is_abort_statement(ast: &Vec) -> bool { + for statement in ast { + if let Statement::Rollback { .. } = *statement { + assert_eq!(ast.len(), 1); + return true; + } } + false } - false -} -/// Returns true if the statement is an abort statement. 
-fn is_abort_statement(ast: &Vec) -> bool { - for statement in ast { - if let Statement::Rollback { .. } = *statement { - assert_eq!(ast.len(), 1); - return true; + /// Returns true if the commit or abort statement should be chained. + fn should_be_chained(dist_commit: Option<&Statement>, dist_abort: Option<&Statement>) -> bool { + matches!( + (dist_commit, dist_abort), + (Some(Statement::Commit { chain: true }), _) + | (_, Some(Statement::Rollback { chain: true })) + ) + } + async fn communicate_err_response( + &mut self, + res: Result, + ) -> Result, Error> { + match res { + Ok(res) => Ok(Some(res)), + Err(Error::ErrorResponse(err)) => { + error_response_stmt(&mut self.write, &err, self.xact_info.state()).await?; + Ok(None) + } + Err(err) => Err(err), } } - false -} -/// Returns true if the commit or abort statement should be chained. -fn should_be_chained(dist_commit: Option<&Statement>, dist_abort: Option<&Statement>) -> bool { - dist_commit - .map(|stmt| match stmt { - Statement::Commit { chain } => *chain, - _ => false, - }) - .unwrap_or(false) - || dist_abort - .map(|stmt| match stmt { - Statement::Rollback { chain } => *chain, - _ => false, - }) - .unwrap_or(false) + pub async fn post_query_processing( + &mut self, + server: &mut Server, + res: Result, + ) -> Result, Error> { + self.set_post_query_state(server); + self.communicate_err_response(res).await + } } /// Send an error response to the client. diff --git a/src/errors.rs b/src/errors.rs index a6aebc50..3a312655 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,5 +1,7 @@ //! Errors. +use crate::query_messages::ErrorResponse; + /// Various errors. 
#[derive(Debug, PartialEq, Clone)] pub enum Error { @@ -29,6 +31,8 @@ pub enum Error { QueryRouterParserError(String), QueryRouterError(String), InvalidShardId(usize), + ErrorResponse(ErrorResponse), + IncompletePacket, } #[derive(Clone, PartialEq, Debug)] diff --git a/src/messages.rs b/src/messages.rs index fa45dcf8..c438800d 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -59,7 +59,7 @@ where auth_ok.put_i32(8); auth_ok.put_i32(0); - write_all(stream, auth_ok).await + write_all(stream, &auth_ok).await } /// Generate md5 password challenge. @@ -81,7 +81,7 @@ where res.put_i32(5); // MD5 res.put_slice(&salt[..]); - write_all(stream, res).await?; + write_all(stream, &res).await?; Ok(salt) } @@ -100,7 +100,7 @@ where key_data.put_i32(backend_id); key_data.put_i32(secret_key); - write_all(stream, key_data).await + write_all(stream, &key_data).await } /// Construct a `Q`: Query message. @@ -142,7 +142,7 @@ where TransactionState::InFailedTransaction => bytes.put_u8(b'E'), } - write_all(stream, bytes).await + write_all(stream, &bytes).await } /// Send the startup packet the server. We're pretending we're a Pg client. @@ -301,7 +301,7 @@ where message.put_i32(password.len() as i32 + 4); message.put_slice(&password[..]); - write_all(stream, message).await + write_all(stream, &message).await } pub async fn md5_password_with_hash(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error> @@ -315,7 +315,7 @@ where message.put_i32(password.len() as i32 + 4); message.put_slice(&password[..]); - write_all(stream, message).await + write_all(stream, &message).await } /// Implements a response to our custom `SET SHARDING KEY` @@ -455,7 +455,7 @@ where res.put(error); - write_all(stream, res).await + write_all(stream, &res).await } /// Respond to a SHOW SHARD command. @@ -613,11 +613,11 @@ pub fn flush() -> BytesMut { } /// Write all data in the buffer to the TcpStream. 
-pub async fn write_all(stream: &mut S, buf: BytesMut) -> Result<(), Error> +pub async fn write_all(stream: &mut S, buf: &BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { - match stream.write_all(&buf).await { + match stream.write_all(buf).await { Ok(_) => Ok(()), Err(err) => Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", diff --git a/src/query_messages.rs b/src/query_messages.rs index f695c95d..fb354d13 100644 --- a/src/query_messages.rs +++ b/src/query_messages.rs @@ -1,3 +1,5 @@ +use std::fmt; + /// Helper functions to send one-off protocol messages /// and handle TcpStream (TCP socket). use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -57,11 +59,11 @@ pub(crate) fn put_cstring(buf: &mut BytesMut, input: &str) { } /// Try to read message length from buf, without actually move the cursor -pub(crate) fn get_length(buf: &BytesMut, offset: usize) -> Option { +pub(crate) fn get_length(buf: &BytesMut, offset: usize) -> Result { if buf.remaining() >= 4 + offset { - Some((&buf[offset..4 + offset]).get_i32() as usize) + Ok((&buf[offset..4 + offset]).get_i32() as usize) } else { - None + Err(Error::IncompletePacket) } } @@ -71,18 +73,18 @@ pub(crate) fn decode_packet( buf: &mut BytesMut, offset: usize, decode_fn: F, -) -> Result, Error> +) -> Result where F: Fn(&mut BytesMut, usize) -> Result, { - if let Some(msg_len) = get_length(buf, offset) { - if buf.remaining() >= msg_len + offset { - buf.advance(offset + 4); - return decode_fn(buf, msg_len).map(|r| Some(r)); - } + let msg_len = get_length(buf, offset)?; + + if buf.remaining() >= msg_len + offset { + buf.advance(offset + 4); + return decode_fn(buf, msg_len); } - Ok(None) + Err(Error::IncompletePacket) } /// Define how message encode and decoded. @@ -121,7 +123,7 @@ pub trait Message: Sized { /// Message type and length are decoded in this implementation and it calls /// `decode_body` for remaining parts. 
Return `None` if the packet is not /// complete for parsing. - fn decode(buf: &mut BytesMut) -> Result, Error> { + fn decode(buf: &mut BytesMut) -> Result { let offset = Self::message_type().is_some().into(); decode_packet(buf, offset, |buf, full_len| { @@ -406,11 +408,17 @@ impl From for ErrorResponse { } /// postgres error response, sent from backend to frontend -#[derive(PartialEq, Eq, Debug, Default)] +#[derive(PartialEq, Eq, Debug, Clone, Default)] pub struct ErrorResponse { fields: Vec<(u8, String)>, } +impl fmt::Display for ErrorResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ErrorResponse({:?})", self.fields) + } +} + pub const MESSAGE_TYPE_BYTE_ERROR_RESPONSE: u8 = b'E'; impl Message for ErrorResponse { diff --git a/src/server.rs b/src/server.rs index aac72d46..dec970b0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -188,7 +188,7 @@ impl ServerParameters { false, ); server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false); - TransactionParameters::set_default_server_parameters(&mut server_parameters); + CommonTxnParams::set_default_server_parameters(&mut server_parameters); server_parameters } @@ -287,7 +287,7 @@ pub struct Server { secret_key: i32, /// Is the server inside a transaction or idle. - pub(crate) transaction_metadata: TransactionMetaData, + pub(crate) transaction_metadata: ServerTxnMetaData, /// Is there more data for the client to read. data_available: bool, @@ -833,7 +833,9 @@ impl Server { }; // We want to make sure that all servers are operating on the same isolation level. 
- sync_given_parameter_keys(&mut server, &TRANSACTION_PARAMETERS).await?; + server + .sync_given_parameter_keys(&TRANSACTION_PARAMETERS) + .await?; return Ok(server); } @@ -927,27 +929,28 @@ impl Server { let code = message.get_u8() as char; let _len = message.get_i32(); - trace!("recv Message: {}", code); + trace!("Message: {}", code); match code { // ReadyForQuery 'Z' => { let transaction_state = message.get_u8() as char; + let params = &mut self.transaction_metadata.params; match transaction_state { // In transaction. 'T' => { - self.transaction_metadata.state = TransactionState::InTransaction; + params.state = TransactionState::InTransaction; } // Idle, transaction over. 'I' => { - self.transaction_metadata.state = TransactionState::Idle; + params.state = TransactionState::Idle; } // Some error occurred, the transaction was rolled back. 'E' => { - self.transaction_metadata.state = TransactionState::InFailedTransaction; + params.state = TransactionState::InFailedTransaction; } // Something totally unexpected, this is not a Postgres server we know. @@ -1241,7 +1244,7 @@ impl Server { pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> { let parameter_diff = self.server_parameters.compare_params(parameters); - sync_given_parameter_key_values(self, ¶meter_diff).await + self.sync_given_parameter_key_values(¶meter_diff).await } /// Indicate that this server connection cannot be re-used and must be discarded. @@ -1267,19 +1270,19 @@ impl Server { /// Execute an arbitrary query against the server. /// It will use the simple query protocol. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. 
- pub async fn query(&mut self, query: &str) -> Result, Error> { + pub async fn query(&mut self, query: &str) -> Result<(), Error> { debug!("Running `{}` on server {}", query, self.address); let query = simple_query(query); self.send(&query).await?; + let mut err = None; loop { let mut response = self.recv(None).await?; - if response[0] == b'E' { - let err = ErrorResponse::decode(&mut response)?.unwrap(); - return Ok(Some(err)); + if response[0] == b'E' && err.is_none() { + err = Some(ErrorResponse::decode(&mut response)); } if !self.data_available { @@ -1287,7 +1290,11 @@ impl Server { } } - Ok(None) + if let Some(err) = err { + return Err(Error::ErrorResponse(err?)); + } + + Ok(()) } /// Perform any necessary cleanup before putting the server @@ -1393,11 +1400,11 @@ impl Server { parse_query_message(&mut message).await } - pub fn transaction_metadata(&self) -> &TransactionMetaData { + pub fn transaction_metadata(&self) -> &ServerTxnMetaData { &self.transaction_metadata } - pub fn transaction_metadata_mut(&mut self) -> &mut TransactionMetaData { + pub fn transaction_metadata_mut(&mut self) -> &mut ServerTxnMetaData { &mut self.transaction_metadata } } diff --git a/src/server_xact.rs b/src/server_xact.rs index 19e27cfd..a9c73655 100644 --- a/src/server_xact.rs +++ b/src/server_xact.rs @@ -1,15 +1,11 @@ /// Implementation of the PostgreSQL server (database) protocol. /// Here we are pretending to the a Postgres client. -use bytes::Buf; -use itertools::Either; use log::{debug, warn}; use once_cell::sync::Lazy; -use sqlparser::ast::{Statement, TransactionIsolationLevel}; +use sqlparser::ast::TransactionIsolationLevel; use std::collections::HashMap; use crate::errors::Error; -use crate::messages::*; -use crate::query_messages::{DataRow, ErrorResponse, Message, QueryResponse, RowDescription}; use crate::server::{Server, ServerParameters}; /// The default transaction parameters that might be configured on the server. 
@@ -24,19 +20,25 @@ pub static TRANSACTION_PARAMETERS: Lazy> = Lazy::new(|| { /// The default transaction parameters that are either configured on the server or set by the /// BEGIN statement. #[derive(Debug, Clone)] -pub struct TransactionParameters { +pub struct CommonTxnParams { + pub(crate) state: TransactionState, + + pub(crate) xact_gid: Option, + isolation_level: TransactionIsolationLevel, read_only: bool, deferrable: bool, } -impl TransactionParameters { +impl CommonTxnParams { pub fn new( isolation_level: TransactionIsolationLevel, read_only: bool, deferrable: bool, ) -> Self { Self { + state: TransactionState::Idle, + xact_gid: None, isolation_level, read_only, deferrable, @@ -106,96 +108,12 @@ impl TransactionParameters { } } -impl Default for TransactionParameters { +impl Default for CommonTxnParams { fn default() -> Self { Self::new(TransactionIsolationLevel::ReadCommitted, false, false) } } -fn get_default_transaction_isolation(sparams: &ServerParameters) -> TransactionIsolationLevel { - // Can unwrap because we set it in the constructor - if let Some(isolation_level) = sparams.parameters.get("default_transaction_isolation") { - return match isolation_level.to_lowercase().as_str() { - "read committed" => TransactionIsolationLevel::ReadCommitted, - "repeatable read" => TransactionIsolationLevel::RepeatableRead, - "serializable" => TransactionIsolationLevel::Serializable, - "read uncommitted" => TransactionIsolationLevel::ReadUncommitted, - _ => TransactionIsolationLevel::ReadCommitted, - }; - } - TransactionIsolationLevel::ReadCommitted -} - -fn get_default_transaction_read_only(sparams: &ServerParameters) -> bool { - if let Some(is_readonly) = sparams.parameters.get("default_transaction_read_only") { - return !is_readonly.to_lowercase().eq("off"); - } - false -} - -fn get_default_transaction_deferrable(sparams: &ServerParameters) -> bool { - if let Some(deferrable) = sparams.parameters.get("default_transaction_deferrable") { - return 
!deferrable.to_lowercase().eq("off"); - } - false -} - -fn get_default_transaction_parameters(sparams: &ServerParameters) -> TransactionParameters { - TransactionParameters::new( - get_default_transaction_isolation(sparams), - get_default_transaction_read_only(sparams), - get_default_transaction_deferrable(sparams), - ) -} - -pub fn server_default_transaction_parameters(server: &Server) -> TransactionParameters { - get_default_transaction_parameters(&server.server_parameters) -} - -/// Sends some queries to the server to sync the given pramaters specified by 'keys'. -pub async fn sync_given_parameter_keys(server: &mut Server, keys: &[String]) -> Result<(), Error> { - let mut key_values = HashMap::new(); - for key in keys { - if let Some(value) = server.server_parameters.parameters.get(key) { - key_values.insert(key.clone(), value.clone()); - } - } - sync_given_parameter_key_values(server, &key_values).await -} - -/// Sends some queries to the server to sync the given pramaters specified by 'key_values'. -pub async fn sync_given_parameter_key_values( - server: &mut Server, - key_values: &HashMap, -) -> Result<(), Error> { - let mut query = String::from(""); - - for (key, value) in key_values { - query.push_str(&format!("SET {} TO '{}';", key, value)); - } - - let res = server.query(&query).await; - - server.cleanup_state.reset(); - - match res { - Ok(None) => Ok(()), - Ok(Some(err_res)) => { - warn!( - "Error while syncing parameters (was dropped): {:?}", - err_res - ); - Ok(()) - } - Err(err) => Err(err), - } -} - -/// Returnes true if the given server is in a failed transaction state. -pub fn in_failed_transaction(server: &Server) -> bool { - server.transaction_metadata.is_in_failed_transaction() -} - /// The various states that a server transaction can be in. #[derive(Debug, PartialEq, Clone, Copy)] pub enum TransactionState { @@ -208,300 +126,209 @@ pub enum TransactionState { } /// The metadata of a server transaction. 
-#[derive(Debug, Clone)] -pub struct TransactionMetaData { - pub(crate) state: TransactionState, - - xact_gid: Option, - snapshot: Option, - - begin_statement: Option, - commit_statement: Option, - abort_statement: Option, +#[derive(Default, Debug, Clone)] +pub struct ServerTxnMetaData { + is_prepared: bool, - pub params: TransactionParameters, + pub params: CommonTxnParams, } -impl TransactionMetaData { - pub fn set_state(&mut self, state: TransactionState) { - match self.state { +pub fn set_state_helper(from_state: &mut TransactionState, to_state: TransactionState) { + match *from_state { + TransactionState::Idle => { + *from_state = to_state; + } + TransactionState::InTransaction => match to_state { TransactionState::Idle => { - self.state = state; + warn!("Cannot go back to idle from a transaction."); } - TransactionState::InTransaction => match state { - TransactionState::Idle => { - warn!("Cannot go back to idle from a transaction."); - } - _ => { - self.state = state; - } - }, - TransactionState::InFailedTransaction => match state { - TransactionState::Idle => { - warn!("Cannot go back to idle from a failed transaction."); - } - TransactionState::InTransaction => { - warn!("Cannot go back to a transaction from a failed transaction.") - } - _ => { - self.state = state; - } - }, - } + _ => { + *from_state = to_state; + } + }, + TransactionState::InFailedTransaction => match to_state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a failed transaction."); + } + TransactionState::InTransaction => { + warn!("Cannot go back to a transaction from a failed transaction.") + } + _ => { + *from_state = to_state; + } + }, + } +} + +impl ServerTxnMetaData { + pub fn set_state(&mut self, state: TransactionState) { + set_state_helper(&mut self.params.state, state); } pub fn state(&self) -> TransactionState { - self.state + self.params.state } pub fn is_idle(&self) -> bool { - self.state == TransactionState::Idle + self.params.state == 
TransactionState::Idle } pub fn is_in_transaction(&self) -> bool { - self.state == TransactionState::InTransaction + self.params.state == TransactionState::InTransaction } pub fn is_in_failed_transaction(&self) -> bool { - self.state == TransactionState::InFailedTransaction + self.params.state == TransactionState::InFailedTransaction } pub fn set_xact_gid(&mut self, xact_gid: Option) { - self.xact_gid = xact_gid; + self.params.xact_gid = xact_gid; } pub fn get_xact_gid(&self) -> Option { - self.xact_gid.clone() - } - - pub fn set_snapshot(&mut self, snapshot: Option) { - self.snapshot = snapshot; - } - - pub fn get_snapshot(&self) -> Option { - self.snapshot.clone() - } - - pub fn set_begin_statement(&mut self, begin_statement: Option) { - self.begin_statement = begin_statement; + self.params.xact_gid.clone() } - pub fn get_begin_statement(&self) -> Option<&Statement> { - self.begin_statement.as_ref() + pub fn set_prepared(&mut self, is_prepared: bool) { + self.is_prepared = is_prepared; } - pub fn set_commit_statement(&mut self, commit_statement: Option) { - self.commit_statement = commit_statement; + pub fn has_done_prepare_transaction(&self) -> bool { + self.is_prepared } +} - pub fn get_commit_statement(&self) -> Option<&Statement> { - self.commit_statement.as_ref() +impl ServerParameters { + fn get_default_transaction_isolation(&self) -> TransactionIsolationLevel { + // Can unwrap because we set it in the constructor + if let Some(isolation_level) = self.parameters.get("default_transaction_isolation") { + return match isolation_level.to_lowercase().as_str() { + "read committed" => TransactionIsolationLevel::ReadCommitted, + "repeatable read" => TransactionIsolationLevel::RepeatableRead, + "serializable" => TransactionIsolationLevel::Serializable, + "read uncommitted" => TransactionIsolationLevel::ReadUncommitted, + _ => TransactionIsolationLevel::ReadCommitted, + }; + } + TransactionIsolationLevel::ReadCommitted } - pub fn set_abort_statement(&mut self, 
abort_statement: Option) { - self.abort_statement = abort_statement; + fn get_default_transaction_read_only(&self) -> bool { + if let Some(is_readonly) = self.parameters.get("default_transaction_read_only") { + return !is_readonly.to_lowercase().eq("off"); + } + false } - pub fn get_abort_statement(&self) -> Option<&Statement> { - self.abort_statement.as_ref() + fn get_default_transaction_deferrable(&self) -> bool { + if let Some(deferrable) = self.parameters.get("default_transaction_deferrable") { + return !deferrable.to_lowercase().eq("off"); + } + false } - pub fn is_transaction_started(&self) -> bool { - self.begin_statement.is_some() + fn get_default_transaction_parameters(&self) -> CommonTxnParams { + CommonTxnParams::new( + self.get_default_transaction_isolation(), + self.get_default_transaction_read_only(), + self.get_default_transaction_deferrable(), + ) } } -impl Default for TransactionMetaData { - fn default() -> Self { - Self { - state: TransactionState::Idle, - xact_gid: None, - snapshot: None, - begin_statement: None, - commit_statement: None, - abort_statement: None, - params: TransactionParameters::default(), - } +impl Server { + pub fn server_default_transaction_parameters(&self) -> CommonTxnParams { + self.server_parameters.get_default_transaction_parameters() } -} -/// Represents a read-write conflict. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct RWConflict { - pub source_gid: String, - pub gid_in: String, - pub gid_out: String, -} - -impl RWConflict { - pub fn new(source_gid: String, gid_in: String, gid_out: String) -> Self { - Self { - source_gid, - gid_in, - gid_out, + /// Sends some queries to the server to sync the given pramaters specified by 'keys'. 
+ pub async fn sync_given_parameter_keys(&mut self, keys: &[String]) -> Result<(), Error> { + let mut key_values = HashMap::new(); + for key in keys { + if let Some(value) = self.server_parameters.parameters.get(key) { + key_values.insert(key.clone(), value.clone()); + } } + self.sync_given_parameter_key_values(&key_values).await } -} -/// Execute an arbitrary query against the server. -/// It will use the simple query protocol. -/// Result will be returned, so this is useful for things like `SELECT`. -async fn query_with_response( - server: &mut Server, - query: &str, -) -> Result, Error> { - debug!( - "Running `{}` on server {} and capture response.", - query, server.address - ); - - let query = simple_query(query); - - server.send(&query).await?; - - // Read all data the server has to offer, which can be multiple messages - // buffered in 8196 bytes chunks. - let mut response = server.recv(None).await?; - - let query_response = match response[0] { - b'T' => { - let row_desc = RowDescription::decode(&mut response)?.unwrap(); - - let mut data_rows = Vec::new(); - - loop { - if response.remaining() == 0 { - if server.is_data_available() { - response = server.recv(None).await?; - } else { - break; - } - } - if response[0] == b'C' { - break; - } - - let data_row = DataRow::decode(&mut response)?.unwrap(); - data_rows.push(data_row); - } - - QueryResponse::new(row_desc, data_rows) - } + /// Sends some queries to the server to sync the given pramaters specified by 'key_values'. 
+ pub async fn sync_given_parameter_key_values( + &mut self, + key_values: &HashMap, + ) -> Result<(), Error> { + let mut query = String::from(""); - b'E' => { - let err = ErrorResponse::decode(&mut response)?.unwrap(); - return Ok(Either::Right(err)); + for (key, value) in key_values { + query.push_str(&format!("SET {} TO '{}';", key, value)); } - _ => return Err(Error::ServerError), - }; + let res = self.query(&query).await; - Ok(Either::Left(query_response)) -} - -/// Captures the snapshot from the server. -pub async fn acquire_xact_snapshot( - server: &mut Server, -) -> Result, Error> { - let qres = query_with_response(server, "select pg_export_snapshot()").await?; + self.cleanup_state.reset(); - if qres.is_right() { - return Ok(Either::Right(qres.right().unwrap())); + match res { + Ok(_) => Ok(()), + Err(Error::ErrorResponse(err_res)) => { + warn!( + "Error while syncing parameters (was dropped): {:?}", + err_res + ); + Ok(()) + } + Err(err) => Err(err), + } } - let qres = qres.left().unwrap(); - - let qres_rows: &[DataRow] = qres.data_rows(); - assert!(qres.row_desc().fields().len() == 1); - assert!(qres_rows.len() == 1); - if let Some(snapshot) = qres_rows[0].fields().get(0).unwrap() { - let snapshot = std::str::from_utf8(snapshot).unwrap().to_string(); - - debug!("Got snapshot: {}", snapshot); - - server - .transaction_metadata - .set_snapshot(Some(snapshot.clone())); - - Ok(Either::Left(snapshot)) - } else { - Err(Error::BadQuery( - "Could not get snapshot from server".to_string(), - )) + /// Returnes true if the given server is in a failed transaction state. + pub fn in_failed_transaction(&self) -> bool { + self.transaction_metadata.is_in_failed_transaction() } -} -/// Sets the snapshot to the server (based on a previous snapshot acquired by the first server). 
-pub async fn assign_xact_snapshot( - server: &mut Server, - snapshot: &str, -) -> Result, Error> { - server - .query(&format!("set transaction snapshot '{snapshot}'")) - .await -} + /// Sets the GID on the server. If we are in serializable mode, we need to register the GID to + /// the remote postgres instance, too. + pub async fn assign_xact_gid(&mut self, gid: &str) -> Result<(), Error> { + self.transaction_metadata + .set_xact_gid(Some(gid.to_string())); + Ok(()) + } -/// Sets the GID on the server. If we are in serializable mode, we need to register the GID to -/// the remote postgres instance, too. -pub async fn assign_xact_gid( - server: &mut Server, - gid: &str, -) -> Result, Error> { - server - .transaction_metadata - .set_xact_gid(Some(gid.to_string())); - Ok(None) -} + pub async fn local_server_prepare_transaction(&mut self) -> Result<(), Error> { + debug!( + "Called local_server_prepare_transaction on {}", + self.address + ); -pub async fn local_server_prepare_transaction( - server: &mut Server, -) -> Result, Error> { - debug!( - "Called local_server_prepare_transaction on {}", - server.address, - ); + let xact_gid = self.transaction_metadata.get_xact_gid(); + if xact_gid.is_none() { + return Err(Error::BadQuery(format!( + "There is no GID assigned to the current transaction while it's requested to be \ + prepared to commit on the server ({}).", + self.address() + ))); + } + let xact_gid = xact_gid.unwrap(); - let xact_gid = server.transaction_metadata.get_xact_gid(); - if xact_gid.is_none() { - return Err(Error::BadQuery(format!( - "There is no GID assigned to the current transaction while it's requested to be \ - prepared to commit on the server ({}).", - server.address() - ))); - } - let xact_gid = xact_gid.unwrap(); + self.query(&format!("PREPARE TRANSACTION '{}'", xact_gid)) + .await?; - let qres = server - .query(&format!("PREPARE TRANSACTION '{}'", xact_gid)) - .await?; - if qres.is_some() { - return Ok(qres); + 
self.transaction_metadata.set_prepared(true); + Ok(()) } - Ok(None) -} - -pub async fn local_server_commit_prepared( - server: &mut Server, -) -> Result, Error> { - debug!("Called local_server_commit_prepared on {}.", server.address); + pub async fn local_server_commit_prepared(&mut self) -> Result<(), Error> { + debug!("Called local_server_commit_prepared on {}.", self.address); - let xact_gid = server.transaction_metadata.get_xact_gid(); - if xact_gid.is_none() { - return Err(Error::BadQuery( - "The current connection is not attached to a \ - transaction while it's requested to be prepared to commit." - .to_string(), - )); - } - let xact_gid = xact_gid.unwrap(); + let xact_gid = self.transaction_metadata.get_xact_gid(); + if xact_gid.is_none() { + return Err(Error::BadQuery( + "The current connection is not attached to a \ + transaction while it's requested to be prepared to commit." + .to_string(), + )); + } + let xact_gid = xact_gid.unwrap(); - let qres = server - .query(&format!("COMMIT PREPARED '{}'", xact_gid)) - .await?; - if qres.is_some() { - return Ok(qres); + self.query(&format!("COMMIT PREPARED '{}'", xact_gid)).await } - - Ok(None) } From 4d7c4835b19081bae92da581bfc0ee21ff2e8e9b Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Mon, 16 Oct 2023 02:23:31 -0700 Subject: [PATCH 14/15] Fixed the build failure issue + Nitz. 
--- src/client.rs | 35 +++++++++++++++++++++++----------- src/client_xact.rs | 12 ++++++++++-- src/query_router.rs | 21 +++++++++++++++++++-- src/server_xact.rs | 46 +++++++++++++++++++++++---------------------- 4 files changed, 77 insertions(+), 37 deletions(-) diff --git a/src/client.rs b/src/client.rs index 8af3d386..04d9f4e2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -854,10 +854,12 @@ where initial_parsed_ast = Some(ast); } Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); + if error != Error::UnsupportedStatement { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } } } } @@ -1309,8 +1311,6 @@ where ControlFlow::Break(ast) => ast, }; - self.assign_client_transaction_state(all_conns); - if all_conns.is_empty() || self.is_transparent_mode() { let current_shard = query_router.shard(); @@ -1554,6 +1554,8 @@ where server.sync_parameters(&self.server_parameters).await?; } + self.assign_client_transaction_state(all_conns); + let is_distributed_xact = all_conns.len() > 1; let server_key = query_router.shard().unwrap_or(0); let conn = all_conns.get_mut(&server_key).unwrap(); @@ -2211,17 +2213,28 @@ fn parse_ast( ) -> Option> { // We don't want to parse again if we already parsed it as the initial message match *initial_parsed_ast { - Some(_) => Some(initial_parsed_ast.take().unwrap()), + Some(_) => { + let parsed_ast = initial_parsed_ast.take().unwrap(); + // if 'parsed_ast' is empty, it means that there was a failed + // attempt to parse the query as a custom command, earlier above. 
+ if parsed_ast.is_empty() { + None + } else { + Some(parsed_ast) + } + } None => match query_router.parse(message) { Ok(ast) => { let _ = query_router.infer(&ast); Some(ast) } Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); + if error != Error::UnsupportedStatement { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } None } }, diff --git a/src/client_xact.rs b/src/client_xact.rs index 86a4b2ef..e5fc59fe 100644 --- a/src/client_xact.rs +++ b/src/client_xact.rs @@ -29,7 +29,7 @@ pub struct ClientTxnMetaData { impl ClientTxnMetaData { pub fn set_state(&mut self, state: TransactionState) { - set_state_helper(&mut self.params.state, state); + CommonTxnParams::set_state_helper(&mut self.params.state, state); } pub fn state(&self) -> TransactionState { @@ -198,13 +198,21 @@ where ); // Set the transaction parameters on the first server. - server.transaction_metadata_mut().params = self.xact_info.params.clone(); + self.set_transaction_params_to_server(server); } else { // If it's not a BEGIN, then it's an irrecoverable error. panic!("The statement is not a BEGIN statement."); } } + fn set_transaction_params_to_server(&mut self, server: &mut Server) { + let server_params = &mut server.transaction_metadata_mut().params; + + server_params.set_isolation_level(self.xact_info.params.get_isolation_level()); + server_params.set_read_only(self.xact_info.params.is_read_only()); + server_params.set_deferrable(self.xact_info.params.is_deferrable()); + } + /// This function performs a distribted abort/commit if necessary, and also resets the transaction /// state. This is suppoed to be called before exiting the transaction loop. At that point, if /// either an abort or commit statement is set, we need to perform a distributed abort/commit. 
This diff --git a/src/query_router.rs b/src/query_router.rs index 8b451dd3..4b6873dc 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -338,6 +338,9 @@ impl QueryRouter { Some((command, value)) } + const UNSUPPORTED_STATEMENTS_FOR_PARSING: [&'static str; 4] = + ["COPY", "SET", "TRUNCATE", "VACUUM"]; + pub fn parse(&self, message: &BytesMut) -> Result, Error> { let mut message_cursor = Cursor::new(message); @@ -380,8 +383,22 @@ impl QueryRouter { match Parser::parse_sql(&PostgreSqlDialect {}, &query) { Ok(ast) => Ok(ast), Err(err) => { - debug!("{}: {}", err, query); - Err(Error::QueryRouterParserError(err.to_string())) + let qry_upper = query.to_ascii_uppercase(); + + // Check for unsupported statements to avoid producing a warning. + // Note 1: this is not a complete list of unsupported statements. + // Note 2: we do not check for unsupported statements before going through the + // parser, as the plugin system might be able to handle them, once sqlparser + // is able to correctly parse these (rather valid) queries. 
+ if Self::UNSUPPORTED_STATEMENTS_FOR_PARSING + .iter() + .any(|s| qry_upper.starts_with(s)) + { + Err(Error::UnsupportedStatement) + } else { + debug!("{}: {}", err, query); + Err(Error::QueryRouterParserError(err.to_string())) + } } } } diff --git a/src/server_xact.rs b/src/server_xact.rs index a9c73655..ee2ff3a2 100644 --- a/src/server_xact.rs +++ b/src/server_xact.rs @@ -133,36 +133,38 @@ pub struct ServerTxnMetaData { pub params: CommonTxnParams, } -pub fn set_state_helper(from_state: &mut TransactionState, to_state: TransactionState) { - match *from_state { - TransactionState::Idle => { - *from_state = to_state; - } - TransactionState::InTransaction => match to_state { - TransactionState::Idle => { - warn!("Cannot go back to idle from a transaction."); - } - _ => { - *from_state = to_state; - } - }, - TransactionState::InFailedTransaction => match to_state { +impl CommonTxnParams { + pub fn set_state_helper(from_state: &mut TransactionState, to_state: TransactionState) { + match *from_state { TransactionState::Idle => { - warn!("Cannot go back to idle from a failed transaction."); - } - TransactionState::InTransaction => { - warn!("Cannot go back to a transaction from a failed transaction.") - } - _ => { *from_state = to_state; } - }, + TransactionState::InTransaction => match to_state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a transaction."); + } + _ => { + *from_state = to_state; + } + }, + TransactionState::InFailedTransaction => match to_state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a failed transaction."); + } + TransactionState::InTransaction => { + warn!("Cannot go back to a transaction from a failed transaction.") + } + _ => { + *from_state = to_state; + } + }, + } } } impl ServerTxnMetaData { pub fn set_state(&mut self, state: TransactionState) { - set_state_helper(&mut self.params.state, state); + CommonTxnParams::set_state_helper(&mut self.params.state, state); } pub fn state(&self) -> 
TransactionState { From 46bd113ae5a1403b9244b4e2c9267dd3ab2836c6 Mon Sep 17 00:00:00 2001 From: Mohammad Dashti Date: Mon, 16 Oct 2023 02:52:17 -0700 Subject: [PATCH 15/15] Fixed `set_state` implementation for `ServerTxnMetaData` + Added a sanity check to `encode` + Nits. --- src/client_xact.rs | 25 ++++++++++++++++++++++++- src/query_messages.rs | 9 +++++++-- src/server.rs | 8 ++++---- src/server_xact.rs | 31 +------------------------------ 4 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/client_xact.rs b/src/client_xact.rs index e5fc59fe..36e7737d 100644 --- a/src/client_xact.rs +++ b/src/client_xact.rs @@ -29,7 +29,30 @@ pub struct ClientTxnMetaData { impl ClientTxnMetaData { pub fn set_state(&mut self, state: TransactionState) { - CommonTxnParams::set_state_helper(&mut self.params.state, state); + match self.params.state { + TransactionState::Idle => { + self.params.state = state; + } + TransactionState::InTransaction => match state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a transaction."); + } + _ => { + self.params.state = state; + } + }, + TransactionState::InFailedTransaction => match state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a failed transaction."); + } + TransactionState::InTransaction => { + warn!("Cannot go back to a transaction from a failed transaction.") + } + _ => { + self.params.state = state; + } + }, + } } pub fn state(&self) -> TransactionState { diff --git a/src/query_messages.rs b/src/query_messages.rs index fb354d13..16575bca 100644 --- a/src/query_messages.rs +++ b/src/query_messages.rs @@ -114,8 +114,13 @@ pub trait Message: Sized { buf.put_u8(mt); } - buf.put_i32(self.message_length() as i32); - self.encode_body(buf) + let message_length = self.message_length(); + let original_buf_len = buf.len(); + buf.put_i32(message_length as i32); + let result = self.encode_body(buf); + assert_eq!(buf.len() - original_buf_len, message_length); + + result } /// Default 
implementation for decoding message. diff --git a/src/server.rs b/src/server.rs index 20b1a8f0..22e58257 100644 --- a/src/server.rs +++ b/src/server.rs @@ -931,21 +931,21 @@ impl Server { 'Z' => { let transaction_state = message.get_u8() as char; - let params = &mut self.transaction_metadata.params; + let metadata = &mut self.transaction_metadata; match transaction_state { // In transaction. 'T' => { - params.state = TransactionState::InTransaction; + metadata.set_state(TransactionState::InTransaction); } // Idle, transaction over. 'I' => { - params.state = TransactionState::Idle; + metadata.set_state(TransactionState::Idle); } // Some error occurred, the transaction was rolled back. 'E' => { - params.state = TransactionState::InFailedTransaction; + metadata.set_state(TransactionState::InFailedTransaction); } // Something totally unexpected, this is not a Postgres server we know. diff --git a/src/server_xact.rs b/src/server_xact.rs index ee2ff3a2..0927af79 100644 --- a/src/server_xact.rs +++ b/src/server_xact.rs @@ -133,38 +133,9 @@ pub struct ServerTxnMetaData { pub params: CommonTxnParams, } -impl CommonTxnParams { - pub fn set_state_helper(from_state: &mut TransactionState, to_state: TransactionState) { - match *from_state { - TransactionState::Idle => { - *from_state = to_state; - } - TransactionState::InTransaction => match to_state { - TransactionState::Idle => { - warn!("Cannot go back to idle from a transaction."); - } - _ => { - *from_state = to_state; - } - }, - TransactionState::InFailedTransaction => match to_state { - TransactionState::Idle => { - warn!("Cannot go back to idle from a failed transaction."); - } - TransactionState::InTransaction => { - warn!("Cannot go back to a transaction from a failed transaction.") - } - _ => { - *from_state = to_state; - } - }, - } - } -} - impl ServerTxnMetaData { pub fn set_state(&mut self, state: TransactionState) { - CommonTxnParams::set_state_helper(&mut self.params.state, state); + self.params.state = state; } 
pub fn state(&self) -> TransactionState {