Skip to content

Commit 4b78af9

Browse files
authored
Implement Close for prepared statements (#482)
* Partial support for Close * Close * respect config value * prepared spec * Hmm * Print cache size
1 parent 73500c0 commit 4b78af9

File tree

8 files changed

+269
-8
lines changed

8 files changed

+269
-8
lines changed

pgcat.toml

+3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ tcp_keepalives_interval = 5
6363
# Handle prepared statements.
6464
prepared_statements = true
6565

66+
# Prepared statements server cache size.
67+
prepared_statements_cache_size = 500
68+
6669
# Path to TLS Certificate file to use for TLS connections
6770
# tls_certificate = ".circleci/server.cert"
6871
# Path to TLS private key file to use for TLS connections

src/admin.rs

+5
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ where
701701
("age_seconds", DataType::Numeric),
702702
("prepare_cache_hit", DataType::Numeric),
703703
("prepare_cache_miss", DataType::Numeric),
704+
("prepare_cache_size", DataType::Numeric),
704705
];
705706

706707
let new_map = get_server_stats();
@@ -732,6 +733,10 @@ where
732733
.prepared_miss_count
733734
.load(Ordering::Relaxed)
734735
.to_string(),
736+
server
737+
.prepared_cache_size
738+
.load(Ordering::Relaxed)
739+
.to_string(),
735740
];
736741

737742
res.put(data_row(&row));

src/client.rs

+49-1
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,19 @@ where
906906
return Ok(());
907907
}
908908

909+
// Close (F)
910+
'C' => {
911+
if prepared_statements_enabled {
912+
let close: Close = (&message).try_into()?;
913+
914+
if close.is_prepared_statement() && !close.anonymous() {
915+
self.prepared_statements.remove(&close.name);
916+
write_all_flush(&mut self.write, &close_complete()).await?;
917+
continue;
918+
}
919+
}
920+
}
921+
909922
_ => (),
910923
}
911924

@@ -1130,7 +1143,17 @@ where
11301143
} else {
11311144
// The statement is not prepared on the server, so we need to prepare it.
11321145
if server.should_prepare(&statement.name) {
1133-
server.prepare(statement).await?;
1146+
match server.prepare(statement).await {
1147+
Ok(_) => (),
1148+
Err(err) => {
1149+
pool.ban(
1150+
&address,
1151+
BanReason::MessageSendFailed,
1152+
Some(&self.stats),
1153+
);
1154+
return Err(err);
1155+
}
1156+
}
11341157
}
11351158
}
11361159

@@ -1251,6 +1274,10 @@ where
12511274
self.stats.disconnect();
12521275
self.release();
12531276

1277+
if prepared_statements_enabled {
1278+
server.maintain_cache().await?;
1279+
}
1280+
12541281
return Ok(());
12551282
}
12561283

@@ -1300,6 +1327,21 @@ where
13001327

13011328
// Close the prepared statement.
13021329
'C' => {
1330+
if prepared_statements_enabled {
1331+
let close: Close = (&message).try_into()?;
1332+
1333+
if close.is_prepared_statement() && !close.anonymous() {
1334+
match self.prepared_statements.get(&close.name) {
1335+
Some(parse) => {
1336+
server.will_close(&parse.generated_name);
1337+
}
1338+
1339+
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
1340+
None => (),
1341+
};
1342+
}
1343+
}
1344+
13031345
self.buffer.put(&message[..]);
13041346
}
13051347

@@ -1433,7 +1475,13 @@ where
14331475

14341476
// The server is no longer bound to us, we can't cancel it's queries anymore.
14351477
debug!("Releasing server back into the pool");
1478+
14361479
server.checkin_cleanup().await?;
1480+
1481+
if prepared_statements_enabled {
1482+
server.maintain_cache().await?;
1483+
}
1484+
14371485
server.stats().idle();
14381486
self.connected_to_server = false;
14391487

src/config.rs

+20-4
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ pub struct General {
323323

324324
#[serde(default)]
325325
pub prepared_statements: bool,
326+
327+
#[serde(default = "General::default_prepared_statements_cache_size")]
328+
pub prepared_statements_cache_size: usize,
326329
}
327330

328331
impl General {
@@ -400,6 +403,10 @@ impl General {
400403
pub fn default_server_round_robin() -> bool {
401404
true
402405
}
406+
407+
pub fn default_prepared_statements_cache_size() -> usize {
408+
500
409+
}
403410
}
404411

405412
impl Default for General {
@@ -438,6 +445,7 @@ impl Default for General {
438445
server_round_robin: false,
439446
validate_config: true,
440447
prepared_statements: false,
448+
prepared_statements_cache_size: 500,
441449
}
442450
}
443451
}
@@ -1020,6 +1028,12 @@ impl Config {
10201028
self.general.verify_server_certificate
10211029
);
10221030
info!("Prepared statements: {}", self.general.prepared_statements);
1031+
if self.general.prepared_statements {
1032+
info!(
1033+
"Prepared statements server cache size: {}",
1034+
self.general.prepared_statements_cache_size
1035+
);
1036+
}
10231037
info!(
10241038
"Plugins: {}",
10251039
match self.plugins {
@@ -1239,13 +1253,15 @@ pub fn get_config() -> Config {
12391253
}
12401254

12411255
pub fn get_idle_client_in_transaction_timeout() -> u64 {
1242-
(*(*CONFIG.load()))
1243-
.general
1244-
.idle_client_in_transaction_timeout
1256+
CONFIG.load().general.idle_client_in_transaction_timeout
12451257
}
12461258

12471259
pub fn get_prepared_statements() -> bool {
1248-
(*(*CONFIG.load())).general.prepared_statements
1260+
CONFIG.load().general.prepared_statements
1261+
}
1262+
1263+
pub fn get_prepared_statements_cache_size() -> usize {
1264+
CONFIG.load().general.prepared_statements_cache_size
12491265
}
12501266

12511267
/// Parse the configuration file located at the path.

src/messages.rs

+79-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/// Helper functions to send one-off protocol messages
22
/// and handle TcpStream (TCP socket).
33
use bytes::{Buf, BufMut, BytesMut};
4-
use log::error;
4+
use log::{debug, error};
55
use md5::{Digest, Md5};
66
use socket2::{SockRef, TcpKeepalive};
77
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -976,6 +976,84 @@ impl Describe {
976976
}
977977
}
978978

979+
/// Close (F) message.
980+
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
981+
#[derive(Clone, Debug)]
982+
pub struct Close {
983+
code: char,
984+
#[allow(dead_code)]
985+
len: i32,
986+
close_type: char,
987+
pub name: String,
988+
}
989+
990+
impl TryFrom<&BytesMut> for Close {
991+
type Error = Error;
992+
993+
fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
994+
let mut cursor = Cursor::new(bytes);
995+
let code = cursor.get_u8() as char;
996+
let len = cursor.get_i32();
997+
let close_type = cursor.get_u8() as char;
998+
let name = cursor.read_string()?;
999+
1000+
Ok(Close {
1001+
code,
1002+
len,
1003+
close_type,
1004+
name,
1005+
})
1006+
}
1007+
}
1008+
1009+
impl TryFrom<Close> for BytesMut {
1010+
type Error = Error;
1011+
1012+
fn try_from(close: Close) -> Result<BytesMut, Error> {
1013+
debug!("Close: {:?}", close);
1014+
1015+
let mut bytes = BytesMut::new();
1016+
let name_binding = CString::new(close.name)?;
1017+
let name = name_binding.as_bytes_with_nul();
1018+
let len = 4 + 1 + name.len();
1019+
1020+
bytes.put_u8(close.code as u8);
1021+
bytes.put_i32(len as i32);
1022+
bytes.put_u8(close.close_type as u8);
1023+
bytes.put_slice(name);
1024+
1025+
Ok(bytes)
1026+
}
1027+
}
1028+
1029+
impl Close {
1030+
pub fn new(name: &str) -> Close {
1031+
let name = name.to_string();
1032+
1033+
Close {
1034+
code: 'C',
1035+
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
1036+
close_type: 'S',
1037+
name,
1038+
}
1039+
}
1040+
1041+
pub fn is_prepared_statement(&self) -> bool {
1042+
self.close_type == 'S'
1043+
}
1044+
1045+
pub fn anonymous(&self) -> bool {
1046+
self.name.is_empty()
1047+
}
1048+
}
1049+
1050+
pub fn close_complete() -> BytesMut {
1051+
let mut bytes = BytesMut::new();
1052+
bytes.put_u8(b'3');
1053+
bytes.put_i32(4);
1054+
bytes
1055+
}
1056+
9791057
pub fn prepared_statement_name() -> String {
9801058
format!(
9811059
"P_{}",

src/server.rs

+74-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use tokio::net::TcpStream;
1515
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
1616
use tokio_rustls::{client::TlsStream, TlsConnector};
1717

18-
use crate::config::{get_config, Address, User};
18+
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
1919
use crate::constants::*;
2020
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
2121
use crate::errors::{Error, ServerIdentifier};
@@ -914,12 +914,16 @@ impl Server {
914914
Ok(bytes)
915915
}
916916

917+
/// Add the prepared statement to being tracked by this server.
918+
/// The client is processing data that will create a prepared statement on this server.
917919
pub fn will_prepare(&mut self, name: &str) {
918920
debug!("Will prepare `{}`", name);
919921

920922
self.prepared_statements.insert(name.to_string());
923+
self.stats.prepared_cache_add();
921924
}
922925

926+
/// Check if we should prepare a statement on the server.
923927
pub fn should_prepare(&self, name: &str) -> bool {
924928
let should_prepare = !self.prepared_statements.contains(name);
925929

@@ -934,6 +938,7 @@ impl Server {
934938
should_prepare
935939
}
936940

941+
/// Create a prepared statement on the server.
937942
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
938943
debug!("Preparing `{}`", parse.name);
939944

@@ -942,15 +947,82 @@ impl Server {
942947
self.send(&flush()).await?;
943948

944949
// Read and discard ParseComplete (B)
945-
let _ = read_message(&mut self.stream).await?;
950+
match read_message(&mut self.stream).await {
951+
Ok(_) => (),
952+
Err(err) => {
953+
self.bad = true;
954+
return Err(err);
955+
}
956+
}
946957

947958
self.prepared_statements.insert(parse.name.to_string());
959+
self.stats.prepared_cache_add();
948960

949961
debug!("Prepared `{}`", parse.name);
950962

951963
Ok(())
952964
}
953965

966+
/// Maintain adequate cache size on the server.
967+
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
968+
debug!("Cache maintenance run");
969+
970+
let max_cache_size = get_prepared_statements_cache_size();
971+
let mut names = Vec::new();
972+
973+
while self.prepared_statements.len() >= max_cache_size {
974+
// The prepared statmeents are alphanumerically sorted by the BTree.
975+
// FIFO.
976+
if let Some(name) = self.prepared_statements.pop_last() {
977+
names.push(name);
978+
}
979+
}
980+
981+
self.deallocate(names).await?;
982+
983+
Ok(())
984+
}
985+
986+
/// Remove the prepared statement from being tracked by this server.
987+
/// The client is processing data that will cause the server to close the prepared statement.
988+
pub fn will_close(&mut self, name: &str) {
989+
debug!("Will close `{}`", name);
990+
991+
self.prepared_statements.remove(name);
992+
}
993+
994+
/// Close a prepared statement on the server.
995+
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
996+
for name in &names {
997+
debug!("Deallocating prepared statement `{}`", name);
998+
999+
let close = Close::new(name);
1000+
let bytes: BytesMut = close.try_into()?;
1001+
1002+
self.send(&bytes).await?;
1003+
}
1004+
1005+
self.send(&flush()).await?;
1006+
1007+
// Read and discard CloseComplete (3)
1008+
for name in &names {
1009+
match read_message(&mut self.stream).await {
1010+
Ok(_) => {
1011+
self.prepared_statements.remove(name);
1012+
self.stats.prepared_cache_remove();
1013+
debug!("Closed `{}`", name);
1014+
}
1015+
1016+
Err(err) => {
1017+
self.bad = true;
1018+
return Err(err);
1019+
}
1020+
};
1021+
}
1022+
1023+
Ok(())
1024+
}
1025+
9541026
/// If the server is still inside a transaction.
9551027
/// If the client disconnects while the server is in a transaction, we will clean it up.
9561028
pub fn in_transaction(&self) -> bool {

0 commit comments

Comments
 (0)