Skip to content

Implement Close for prepared statements #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ tcp_keepalives_interval = 5
# Handle prepared statements.
prepared_statements = true

# Prepared statements server cache size.
prepared_statements_cache_size = 500

# Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
Expand Down
5 changes: 5 additions & 0 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ where
("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric),
("prepare_cache_size", DataType::Numeric),
];

let new_map = get_server_stats();
Expand Down Expand Up @@ -732,6 +733,10 @@ where
.prepared_miss_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_cache_size
.load(Ordering::Relaxed)
.to_string(),
];

res.put(data_row(&row));
Expand Down
50 changes: 49 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,19 @@ where
return Ok(());
}

// Close (F)
'C' => {
if prepared_statements_enabled {
let close: Close = (&message).try_into()?;

if close.is_prepared_statement() && !close.anonymous() {
self.prepared_statements.remove(&close.name);
write_all_flush(&mut self.write, &close_complete()).await?;
continue;
}
}
}

_ => (),
}

Expand Down Expand Up @@ -1130,7 +1143,17 @@ where
} else {
// The statement is not prepared on the server, so we need to prepare it.
if server.should_prepare(&statement.name) {
server.prepare(statement).await?;
match server.prepare(statement).await {
Ok(_) => (),
Err(err) => {
pool.ban(
&address,
BanReason::MessageSendFailed,
Some(&self.stats),
);
return Err(err);
}
}
}
}

Expand Down Expand Up @@ -1251,6 +1274,10 @@ where
self.stats.disconnect();
self.release();

if prepared_statements_enabled {
server.maintain_cache().await?;
}

return Ok(());
}

Expand Down Expand Up @@ -1300,6 +1327,21 @@ where

// Close the prepared statement.
'C' => {
if prepared_statements_enabled {
let close: Close = (&message).try_into()?;

if close.is_prepared_statement() && !close.anonymous() {
match self.prepared_statements.get(&close.name) {
Some(parse) => {
server.will_close(&parse.generated_name);
}

// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
None => (),
};
}
}

self.buffer.put(&message[..]);
}

Expand Down Expand Up @@ -1433,7 +1475,13 @@ where

// The server is no longer bound to us, we can't cancel it's queries anymore.
debug!("Releasing server back into the pool");

server.checkin_cleanup().await?;

if prepared_statements_enabled {
server.maintain_cache().await?;
}

server.stats().idle();
self.connected_to_server = false;

Expand Down
24 changes: 20 additions & 4 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ pub struct General {

#[serde(default)]
pub prepared_statements: bool,

#[serde(default = "General::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
}

impl General {
Expand Down Expand Up @@ -400,6 +403,10 @@ impl General {
pub fn default_server_round_robin() -> bool {
true
}

pub fn default_prepared_statements_cache_size() -> usize {
500
}
}

impl Default for General {
Expand Down Expand Up @@ -438,6 +445,7 @@ impl Default for General {
server_round_robin: false,
validate_config: true,
prepared_statements: false,
prepared_statements_cache_size: 500,
}
}
}
Expand Down Expand Up @@ -1020,6 +1028,12 @@ impl Config {
self.general.verify_server_certificate
);
info!("Prepared statements: {}", self.general.prepared_statements);
if self.general.prepared_statements {
info!(
"Prepared statements server cache size: {}",
self.general.prepared_statements_cache_size
);
}
info!(
"Plugins: {}",
match self.plugins {
Expand Down Expand Up @@ -1239,13 +1253,15 @@ pub fn get_config() -> Config {
}

pub fn get_idle_client_in_transaction_timeout() -> u64 {
(*(*CONFIG.load()))
.general
.idle_client_in_transaction_timeout
CONFIG.load().general.idle_client_in_transaction_timeout
}

pub fn get_prepared_statements() -> bool {
(*(*CONFIG.load())).general.prepared_statements
CONFIG.load().general.prepared_statements
}

pub fn get_prepared_statements_cache_size() -> usize {
CONFIG.load().general.prepared_statements_cache_size
}

/// Parse the configuration file located at the path.
Expand Down
80 changes: 79 additions & 1 deletion src/messages.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut};
use log::error;
use log::{debug, error};
use md5::{Digest, Md5};
use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
Expand Down Expand Up @@ -976,6 +976,84 @@ impl Describe {
}
}

/// Close (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Close {
code: char,
#[allow(dead_code)]
len: i32,
close_type: char,
pub name: String,
}

impl TryFrom<&BytesMut> for Close {
type Error = Error;

fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let close_type = cursor.get_u8() as char;
let name = cursor.read_string()?;

Ok(Close {
code,
len,
close_type,
name,
})
}
}

impl TryFrom<Close> for BytesMut {
type Error = Error;

fn try_from(close: Close) -> Result<BytesMut, Error> {
debug!("Close: {:?}", close);

let mut bytes = BytesMut::new();
let name_binding = CString::new(close.name)?;
let name = name_binding.as_bytes_with_nul();
let len = 4 + 1 + name.len();

bytes.put_u8(close.code as u8);
bytes.put_i32(len as i32);
bytes.put_u8(close.close_type as u8);
bytes.put_slice(name);

Ok(bytes)
}
}

impl Close {
pub fn new(name: &str) -> Close {
let name = name.to_string();

Close {
code: 'C',
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
close_type: 'S',
name,
}
}

pub fn is_prepared_statement(&self) -> bool {
self.close_type == 'S'
}

pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
}

pub fn close_complete() -> BytesMut {
let mut bytes = BytesMut::new();
bytes.put_u8(b'3');
bytes.put_i32(4);
bytes
}

pub fn prepared_statement_name() -> String {
format!(
"P_{}",
Expand Down
76 changes: 74 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector};

use crate::config::{get_config, Address, User};
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
Expand Down Expand Up @@ -914,12 +914,16 @@ impl Server {
Ok(bytes)
}

/// Add the prepared statement to being tracked by this server.
/// The client is processing data that will create a prepared statement on this server.
pub fn will_prepare(&mut self, name: &str) {
debug!("Will prepare `{}`", name);

self.prepared_statements.insert(name.to_string());
self.stats.prepared_cache_add();
}

/// Check if we should prepare a statement on the server.
pub fn should_prepare(&self, name: &str) -> bool {
let should_prepare = !self.prepared_statements.contains(name);

Expand All @@ -934,6 +938,7 @@ impl Server {
should_prepare
}

/// Create a prepared statement on the server.
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
debug!("Preparing `{}`", parse.name);

Expand All @@ -942,15 +947,82 @@ impl Server {
self.send(&flush()).await?;

// Read and discard ParseComplete (B)
let _ = read_message(&mut self.stream).await?;
match read_message(&mut self.stream).await {
Ok(_) => (),
Err(err) => {
self.bad = true;
return Err(err);
}
}

self.prepared_statements.insert(parse.name.to_string());
self.stats.prepared_cache_add();

debug!("Prepared `{}`", parse.name);

Ok(())
}

/// Maintain adequate cache size on the server.
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
debug!("Cache maintenance run");

let max_cache_size = get_prepared_statements_cache_size();
let mut names = Vec::new();

while self.prepared_statements.len() >= max_cache_size {
// The prepared statmeents are alphanumerically sorted by the BTree.
// FIFO.
if let Some(name) = self.prepared_statements.pop_last() {
names.push(name);
}
}

self.deallocate(names).await?;

Ok(())
}

/// Remove the prepared statement from being tracked by this server.
/// The client is processing data that will cause the server to close the prepared statement.
pub fn will_close(&mut self, name: &str) {
debug!("Will close `{}`", name);

self.prepared_statements.remove(name);
}

/// Close a prepared statement on the server.
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
for name in &names {
debug!("Deallocating prepared statement `{}`", name);

let close = Close::new(name);
let bytes: BytesMut = close.try_into()?;

self.send(&bytes).await?;
}

self.send(&flush()).await?;

// Read and discard CloseComplete (3)
for name in &names {
match read_message(&mut self.stream).await {
Ok(_) => {
self.prepared_statements.remove(name);
self.stats.prepared_cache_remove();
debug!("Closed `{}`", name);
}

Err(err) => {
self.bad = true;
return Err(err);
}
};
}

Ok(())
}

/// If the server is still inside a transaction.
/// If the client disconnects while the server is in a transaction, we will clean it up.
pub fn in_transaction(&self) -> bool {
Expand Down
Loading