Skip to content

Added the transparent mode #600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Applied cargo clippy suggestions.
  • Loading branch information
mdashti committed Oct 6, 2023
commit a4a554d14ac7fac1f9144482f7b934763a038f40
114 changes: 51 additions & 63 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ pub async fn client_entrypoint(
// Client requested a TLS connection.
Ok((ClientConnectionType::Tls, _)) => {
// TLS settings are configured, will setup TLS now.
if tls_certificate != None {
if tls_certificate.is_some() {
debug!("Accepting TLS request");

let mut yes = BytesMut::new();
Expand Down Expand Up @@ -455,7 +455,7 @@ where
None => "pgcat",
};

let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
let client_identifier = ClientIdentifier::new(application_name, username, pool_name);

let admin = ["pgcat", "pgbouncer"]
.iter()
Expand Down Expand Up @@ -806,7 +806,7 @@ where
let mut will_prepare = false;

let client_identifier = ClientIdentifier::new(
&self.server_parameters.get_application_name(),
self.server_parameters.get_application_name(),
&self.username,
&self.pool_name,
);
Expand Down Expand Up @@ -993,15 +993,11 @@ where
}

// Check on plugin results.
match plugin_output {
Some(PluginOutput::Deny(error)) => {
self.buffer.clear();
error_response(&mut self.write, &error).await?;
plugin_output = None;
continue;
}

_ => (),
if let Some(PluginOutput::Deny(error)) = plugin_output {
self.buffer.clear();
error_response(&mut self.write, &error).await?;
plugin_output = None;
continue;
};

// Check if the pool is paused and wait until it's resumed.
Expand Down Expand Up @@ -1100,50 +1096,47 @@ where

// Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.get(0).unwrap() as char;
let code = *message.first().unwrap() as char;
let mut ast = None;

match code {
// Query
'Q' => {
// If the first message is a `BEGIN` statement, then we are starting a
// transaction. However, we might not still be on the right shard (as the
// shard might be inferred from the first query). So we parse the query and
// store the `BEGIN` statement. Upon receiving the next query (and possibly
// determining the shard), we will execute the `BEGIN` statement.
if let Some(ast_vec) = initial_parsed_ast.as_ref() {
if Self::is_begin_statement(ast_vec) {
assert_eq!(ast_vec.len(), 1);

initialize_xact_info(self, &ast_vec[0]);

custom_protocol_response_ok_with_state(
&mut self.write,
"BEGIN",
self.xact_info.state(),
)
.await?;
// Query
if code == 'Q' {
// If the first message is a `BEGIN` statement, then we are starting a
// transaction. However, we might not still be on the right shard (as the
// shard might be inferred from the first query). So we parse the query and
// store the `BEGIN` statement. Upon receiving the next query (and possibly
// determining the shard), we will execute the `BEGIN` statement.
if let Some(ast_vec) = initial_parsed_ast.as_ref() {
if Self::is_begin_statement(ast_vec) {
assert_eq!(ast_vec.len(), 1);

continue;
}
initialize_xact_info(self, &ast_vec[0]);

custom_protocol_response_ok_with_state(
&mut self.write,
"BEGIN",
self.xact_info.state(),
)
.await?;

continue;
}
}

if query_router.query_parser_enabled() {
let should_continue;
(should_continue, ast) = self
.parse_ast_helper(
&mut query_router,
&mut initial_parsed_ast,
&message,
&client_identifier,
)
.await?;
if !should_continue {
continue;
}
if query_router.query_parser_enabled() {
let should_continue;
(should_continue, ast) = self
.parse_ast_helper(
&mut query_router,
&mut initial_parsed_ast,
&message,
&client_identifier,
)
.await?;
if !should_continue {
continue;
}
}
_ => (),
};

assign_client_transaction_state(self, &all_conns);
Expand Down Expand Up @@ -1487,10 +1480,8 @@ where
}

if let Some(ast) = &ast {
if is_distributed_xact {
if set_commit_or_abort_statement(self, &ast) {
break;
}
if is_distributed_xact && set_commit_or_abort_statement(self, ast) {
break;
}
}
}
Expand Down Expand Up @@ -1536,7 +1527,7 @@ where
self.stats.transaction();
server
.stats()
.transaction(&self.server_parameters.get_application_name());
.transaction(self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction or transparent modes.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -1613,13 +1604,10 @@ where
let close: Close = (&message).try_into()?;

if close.is_prepared_statement() && !close.anonymous() {
match self.prepared_statements.get(&close.name) {
Some(parse) => {
server.will_close(&parse.generated_name);
}

if let Some(parse) = self.prepared_statements.get(&close.name) {
server.will_close(&parse.generated_name);
} else {
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
None => (),
};
}
}
Expand Down Expand Up @@ -1663,7 +1651,7 @@ where

self.buffer.put(&message[..]);

let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char;

// Almost certainly true
if first_message_code == 'P' && !prepared_statements_enabled {
Expand Down Expand Up @@ -1695,7 +1683,7 @@ where
self.stats.transaction();
server
.stats()
.transaction(&self.server_parameters.get_application_name());
.transaction(self.server_parameters.get_application_name());

// Release server back to the pool if we are in transaction or transparent modes.
// If we are in session mode, we keep the server until the client disconnects.
Expand Down Expand Up @@ -2011,7 +1999,7 @@ where
client_stats.query();
server.stats().query(
Instant::now().duration_since(query_start).as_millis() as u64,
&self.server_parameters.get_application_name(),
self.server_parameters.get_application_name(),
);

Ok(())
Expand Down
67 changes: 28 additions & 39 deletions src/client_xact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::client::Client;
use crate::errors::Error;
use crate::query_messages::{ErrorInfo, ErrorResponse, Message};
use bytes::BytesMut;
use core::panic;
use futures::future::join_all;
use itertools::Either;
use log::{debug, warn};
Expand Down Expand Up @@ -127,11 +128,7 @@ where

/// Generates a random GID (i.e., Global transaction ID) for a transaction.
fn generate_xact_gid<S, T>(clnt: &Client<S, T>) -> String {
format!(
"txn_{}_{}",
clnt.addr.to_string(),
Uuid::new_v4().to_string()
)
format!("txn_{}_{}", clnt.addr, Uuid::new_v4())
}

/// Generates a server-specific GID for a transaction. We need this, because it's possible that
Expand All @@ -148,14 +145,12 @@ pub fn assign_client_transaction_state<S, T>(
clnt.xact_info.set_state(if all_conns.is_empty() {
// if there's no server, we're in idle mode.
TransactionState::Idle
} else if is_any_server_in_failed_xact(all_conns) {
// if any server is in failed transaction, we're in failed transaction.
TransactionState::InFailedTransaction
} else {
if is_any_server_in_failed_xact(all_conns) {
// if any server is in failed transaction, we're in failed transaction.
TransactionState::InFailedTransaction
} else {
// if we have at least one server and it is in a transaction, we're in a transaction.
TransactionState::InTransaction
}
// if we have at least one server and it is in a transaction, we're in a transaction.
TransactionState::InTransaction
});
}

Expand All @@ -165,7 +160,7 @@ fn is_any_server_in_failed_xact(
) -> bool {
all_conns
.iter()
.any(|(_, conn)| in_failed_transaction(&*conn.0))
.any(|(_, conn)| in_failed_transaction(&conn.0))
}

/// This function initializes the transaction parameters based on the first BEGIN statement.
Expand All @@ -177,8 +172,8 @@ pub fn initialize_xact_info<S, T>(clnt: &mut Client<S, T>, begin_stmt: &Statemen
clnt.xact_info.set_state(TransactionState::InTransaction);
} else {
// If we were not in a transaction and the first statement is
// not a BEGIN, then it's an irrecovable error.
assert!(false);
// not a BEGIN, then it's an irrecoverable error.
panic!("The first statement in a transaction is not a BEGIN statement.");
}
}

Expand All @@ -200,9 +195,7 @@ pub fn initialize_xact_params<S, T>(
});
}
TransactionMode::IsolationLevel(isolation_level) => {
clnt.xact_info
.params
.set_isolation_level(isolation_level.clone());
clnt.xact_info.params.set_isolation_level(*isolation_level);
}
}
}
Expand All @@ -214,8 +207,8 @@ pub fn initialize_xact_params<S, T>(
// Set the transaction parameters on the first server.
server.transaction_metadata_mut().params = clnt.xact_info.params.clone();
} else {
// If it's not a BEGIN, then it's an irrecovable error.
assert!(false);
// If it's not a BEGIN, then it's an irrecoverable error.
panic!("The statement is not a BEGIN statement.");
}
}

Expand All @@ -238,9 +231,9 @@ where
{
let dist_commit = clnt.xact_info.get_commit_statement();
let dist_abort = clnt.xact_info.get_abort_statement();
Ok(if dist_commit.is_some() || dist_abort.is_some() {
if dist_commit.is_some() || dist_abort.is_some() {
// if either a commit or abort statement is set, we should be in a distributed transaction.
assert!(all_conns.len() > 0);
assert!(!all_conns.is_empty());

let is_chained = should_be_chained(dist_commit, dist_abort);
let dist_commit = dist_commit.map(|stmt| stmt.to_string());
Expand Down Expand Up @@ -316,7 +309,8 @@ where
last_server.address()
);
}
})
}
Ok(())
}

pub fn reset_client_xact<S, T>(clnt: &mut Client<S, T>) {
Expand Down Expand Up @@ -396,7 +390,7 @@ async fn distributed_abort<S, T>(
ServerId,
(bb8::PooledConnection<'_, crate::pool::ServerPool>, Address),
>,
abort_stmt: &String,
abort_stmt: &str,
) -> Result<Option<ErrorResponse>, Error>
where
S: tokio::io::AsyncRead + std::marker::Unpin,
Expand All @@ -414,7 +408,7 @@ where
set_post_query_state(clnt, server);
server
.stats()
.transaction(&clnt.server_parameters.get_application_name());
.transaction(clnt.server_parameters.get_application_name());
});

for abort_res in abort_results {
Expand All @@ -426,6 +420,7 @@ where
Ok(None)
}

#[allow(clippy::type_complexity)]
async fn distributed_prepare<S, T>(
clnt: &mut Client<S, T>,
all_conns: &mut HashMap<
Expand Down Expand Up @@ -479,10 +474,10 @@ where
/// Returns true if the statement is a commit or abort statement. Also, it sets the commit or abort
/// statement on the client.
pub fn set_commit_or_abort_statement<S, T>(clnt: &mut Client<S, T>, ast: &Vec<Statement>) -> bool {
if is_commit_statement(&ast) {
if is_commit_statement(ast) {
clnt.xact_info.set_commit_statement(Some(ast[0].clone()));
true
} else if is_abort_statement(&ast) {
} else if is_abort_statement(ast) {
clnt.xact_info.set_abort_statement(Some(ast[0].clone()));
true
} else {
Expand All @@ -493,12 +488,9 @@ pub fn set_commit_or_abort_statement<S, T>(clnt: &mut Client<S, T>, ast: &Vec<St
/// Returns true if the statement is a commit statement.
fn is_commit_statement(ast: &Vec<Statement>) -> bool {
for statement in ast {
match *statement {
Statement::Commit { .. } => {
assert_eq!(ast.len(), 1);
return true;
}
_ => (),
if let Statement::Commit { .. } = *statement {
assert_eq!(ast.len(), 1);
return true;
}
}
false
Expand All @@ -507,12 +499,9 @@ fn is_commit_statement(ast: &Vec<Statement>) -> bool {
/// Returns true if the statement is an abort statement.
fn is_abort_statement(ast: &Vec<Statement>) -> bool {
for statement in ast {
match *statement {
Statement::Rollback { .. } => {
assert_eq!(ast.len(), 1);
return true;
}
_ => (),
if let Statement::Rollback { .. } = *statement {
assert_eq!(ast.len(), 1);
return true;
}
}
false
Expand Down
1 change: 1 addition & 0 deletions src/query_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ pub struct ErrorInfo {
}

impl ErrorInfo {
#[allow(clippy::too_many_arguments)]
pub fn new(
severity: String,
code: String,
Expand Down
Loading