diff --git a/README.md b/README.md
index 63b5ab1a..051a3536 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
 | Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
 | Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
 | Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. |
+| Password rotation | **Experimental** | Rotate user passwords without downtime and without third-party tools to manage Postgres authentication. |
 
 ## Status
 
@@ -244,6 +245,12 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu
 
 Mirroring allows to route queries to multiple databases at the same time. This is useful for prewarning replicas before placing them into the active configuration, or for testing different versions of Postgres with live traffic.
 
+### Password rotation
+
+Password rotation lets you specify multiple passwords (secrets) for a user, so clients can connect to PgCat with any of them. Distributed applications can then update their configuration (connection strings) gradually, while PgCat tracks the rollout in admin statistics. Once the new secret is deployed everywhere, the old one can be removed from PgCat.
+
+This also decouples server passwords from client passwords, so one can be changed without necessarily changing the other.
+
 ## License
 
 PgCat is free and open source, released under the MIT license.
diff --git a/pgcat.toml b/pgcat.toml
index 0d883a33..3ef929da 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -58,9 +58,9 @@ tcp_keepalives_count = 5
 tcp_keepalives_interval = 5
 
 # Path to TLS Certficate file to use for TLS connections
-# tls_certificate = "server.cert"
+# tls_certificate = ".circleci/server.cert"
 # Path to TLS private key file to use for TLS connections
-# tls_private_key = "server.key"
+# tls_private_key = ".circleci/server.key"
 
 # User name to access the virtual administrative database (pgbouncer or pgcat)
 # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -129,6 +129,10 @@ connect_timeout = 3000
 username = "sharding_user"
 # Postgresql password
 password = "sharding_user"
+
+# Passwords the client can use to connect. Useful for password rotations.
+secrets = [ "secret_one", "secret_two" ]
+
 # Maximum number of server connections that can be established for this user
 # The maximum number of connection from a single Pgcat process to any database in the cluster
 # is the sum of pool_size across all users.
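The password-rotation section and the `secrets` setting above boil down to one idea: pools are keyed by database, user, and an optional secret, and a client-supplied secret is tried before falling back to the pool keyed by the regular password. A minimal sketch of that lookup order, using made-up stand-in types rather than PgCat's actual structs:

```rust
use std::collections::HashMap;

// `PoolKey` is a hypothetical stand-in for PoolIdentifier with its new `secret`
// field; the String value stands in for a connection pool.
#[derive(Hash, PartialEq, Eq)]
struct PoolKey {
    db: String,
    user: String,
    secret: Option<String>,
}

fn lookup<'a>(
    pools: &'a HashMap<PoolKey, String>,
    db: &str,
    user: &str,
    secret: Option<&str>,
) -> Option<&'a String> {
    let key = |secret: Option<String>| PoolKey {
        db: db.to_string(),
        user: user.to_string(),
        secret,
    };

    // A pool registered for this exact secret wins; otherwise fall back to the
    // pool registered for the configured password (the `None` entry).
    pools
        .get(&key(secret.map(|s| s.to_string())))
        .or_else(|| pools.get(&key(None)))
}
```

This mirrors the fallback `ClearText::authenticate` performs in `src/auth.rs` below: an unknown secret is still checked against the configured `password` before the client is rejected.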
diff --git a/src/admin.rs b/src/admin.rs
index 03af755c..4cca2ce7 100644
--- a/src/admin.rs
+++ b/src/admin.rs
@@ -259,6 +259,7 @@ where
     let columns = vec![
         ("database", DataType::Text),
         ("user", DataType::Text),
+        ("secret", DataType::Text),
        ("pool_mode", DataType::Text),
         ("cl_idle", DataType::Numeric),
         ("cl_active", DataType::Numeric),
@@ -276,10 +277,11 @@ where
     let mut res = BytesMut::new();
     res.put(row_description(&columns));
 
-    for ((_user_pool, _pool), pool_stats) in all_pool_stats {
+    for (_, pool_stats) in all_pool_stats {
         let mut row = vec![
             pool_stats.database(),
             pool_stats.user(),
+            pool_stats.redacted_secret(),
             pool_stats.pool_mode().to_string(),
         ];
         pool_stats.populate_row(&mut row);
@@ -780,7 +782,7 @@ where
     let database = parts[0];
     let user = parts[1];
 
-    match get_pool(database, user) {
+    match get_pool(database, user, None) {
         Some(pool) => {
             pool.pause();
 
@@ -827,7 +829,7 @@ where
     let database = parts[0];
     let user = parts[1];
 
-    match get_pool(database, user) {
+    match get_pool(database, user, None) {
         Some(pool) => {
             pool.resume();
 
@@ -895,13 +897,20 @@ where
     res.put(row_description(&vec![
         ("name", DataType::Text),
         ("pool_mode", DataType::Text),
+        ("secret", DataType::Text),
     ]));
 
     for (user_pool, pool) in get_all_pools() {
         let pool_config = &pool.settings;
+        let redacted_secret = match user_pool.secret {
+            Some(secret) => format!("****{}", &secret[secret.len() - 4..]),
+            None => "".to_string(),
+        };
+
         res.put(data_row(&vec![
             user_pool.user.clone(),
             pool_config.pool_mode.to_string(),
+            redacted_secret,
         ]));
     }
 
diff --git a/src/auth.rs b/src/auth.rs
new file mode 100644
index 00000000..cd6fae3b
--- /dev/null
+++ b/src/auth.rs
@@ -0,0 +1,405 @@
+//! Module implementing various client authentication mechanisms.
+//!
+//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection).
+
+use crate::errors::Error;
+use crate::tokio::io::AsyncReadExt;
+use crate::{
+    auth_passthrough::AuthPassthrough,
+    config::get_config,
+    messages::{
+        error_response, md5_hash_password, md5_hash_second_pass, write_all, wrong_password,
+    },
+    pool::{get_pool, ConnectionPool},
+};
+use bytes::{BufMut, BytesMut};
+use log::debug;
+
+async fn refetch_auth_hash<S>(
+    pool: &ConnectionPool,
+    stream: &mut S,
+    username: &str,
+    pool_name: &str,
+) -> Result<String, Error>
+where
+    S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
+{
+    let address = pool.address(0, 0);
+    if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
+        let hash = apt.fetch_hash(address).await?;
+
+        return Ok(hash);
+    }
+
+    error_response(
+        stream,
+        &format!(
+            "No password set and auth passthrough failed for database: {}, user: {}",
+            pool_name, username
+        ),
+    )
+    .await?;
+
+    Err(Error::ClientError(format!(
+        "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
+        address.username, address.database
+    )))
+}
+
+/// Read 'p' message from client.
+async fn response<R>(stream: &mut R) -> Result<Vec<u8>, Error>
+where
+    R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
+{
+    let code = match stream.read_u8().await {
+        Ok(code) => code,
+        Err(_) => {
+            return Err(Error::SocketError(
+                "Error reading password code from client".to_string(),
+            ))
+        }
+    };
+
+    if code as char != 'p' {
+        return Err(Error::SocketError(format!("Expected p, got {}", code)));
+    }
+
+    let len = match stream.read_i32().await {
+        Ok(len) => len,
+        Err(_) => {
+            return Err(Error::SocketError(
+                "Error reading password length from client".to_string(),
+            ))
+        }
+    };
+
+    let mut response = vec![0; (len - 4) as usize];
+
+    match stream.read_exact(&mut response).await {
+        Ok(_) => (),
+        Err(_) => {
+            return Err(Error::SocketError(
+                "Error reading password from client".to_string(),
+            ))
+        }
+    };
+
+    Ok(response.to_vec())
+}
+
+/// Make sure the pool we authenticated to has at least one server connection
+/// that can serve our request.
+async fn validate_pool<W>(
+    stream: &mut W,
+    mut pool: ConnectionPool,
+    username: &str,
+    pool_name: &str,
+) -> Result<(), Error>
+where
+    W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
+{
+    if !pool.validated() {
+        match pool.validate().await {
+            Ok(_) => Ok(()),
+            Err(err) => {
+                error_response(
+                    stream,
+                    &format!(
+                        "Pool down for database: {:?}, user: {:?}",
+                        pool_name, username,
+                    ),
+                )
+                .await?;
+
+                Err(Error::ClientError(format!("Pool down: {:?}", err)))
+            }
+        }
+    } else {
+        Ok(())
+    }
+}
+
+/// Clear text authentication.
+///
+/// The client will send the password in plain text over the wire.
+/// To protect against obvious security issues, this is only used over TLS.
+///
+/// Clear text authentication is used to support zero-downtime password rotation.
+/// It allows the client to use multiple passwords when talking to the PgCat
+/// while the password is being rotated across multiple app instances.
+pub struct ClearText {
+    username: String,
+    pool_name: String,
+    application_name: String,
+}
+
+impl ClearText {
+    /// Create a new ClearText authentication mechanism.
+    pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText {
+        ClearText {
+            username: username.to_string(),
+            pool_name: pool_name.to_string(),
+            application_name: application_name.to_string(),
+        }
+    }
+
+    /// Issue 'R' clear text challenge to client.
+    pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
+    where
+        W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
+    {
+        debug!("Sending plain challenge");
+
+        let mut msg = BytesMut::new();
+        msg.put_u8(b'R');
+        msg.put_i32(8);
+        msg.put_i32(3); // Clear text
+
+        write_all(stream, msg).await
+    }
+
+    /// Authenticate client with server password or secret.
+    pub async fn authenticate<R, W>(
+        &self,
+        read: &mut R,
+        write: &mut W,
+    ) -> Result<Option<String>, Error>
+    where
+        R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
+        W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
+    {
+        let response = response(read).await?;
+
+        let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string();
+
+        match get_pool(&self.pool_name, &self.username, Some(secret.clone())) {
+            None => match get_pool(&self.pool_name, &self.username, None) {
+                Some(pool) => {
+                    match pool.settings.user.password {
+                        Some(ref password) => {
+                            if password != &secret {
+                                wrong_password(write, &self.username).await?;
+                                Err(Error::ClientError(format!(
+                                    "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
+                                    self.username, self.pool_name, self.application_name
+                                )))
+                            } else {
+                                validate_pool(write, pool, &self.username, &self.pool_name).await?;
+
+                                Ok(None)
+                            }
+                        }
+
+                        None => {
+                            // Server is storing hashes, we can't query it for the plain text password.
+                            error_response(
+                                write,
+                                &format!(
+                                    "No server password configured for database: {:?}, user: {:?}",
+                                    self.pool_name, self.username
+                                ),
+                            )
+                            .await?;
+
+                            Err(Error::ClientError(format!(
+                                "No server password configured for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
+                                self.username, self.pool_name, self.application_name
+                            )))
+                        }
+                    }
+                }
+
+                None => {
+                    error_response(
+                        write,
+                        &format!(
+                            "No pool configured for database: {:?}, user: {:?}",
+                            self.pool_name, self.username
+                        ),
+                    )
+                    .await?;
+
+                    Err(Error::ClientError(format!(
+                        "Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
+                        self.username, self.pool_name, self.application_name
+                    )))
+                }
+            },
+            Some(pool) => {
+                validate_pool(write, pool, &self.username, &self.pool_name).await?;
+                Ok(Some(secret))
+            }
+        }
+    }
+}
+
+/// MD5 hash authentication.
+///
+/// Deprecated, but widely used everywhere, and currently required for poolers
+/// to authenticate clients without involving Postgres.
+///
+/// Admin clients are required to use MD5.
+pub struct Md5 {
+    username: String,
+    pool_name: String,
+    application_name: String,
+    salt: [u8; 4],
+    admin: bool,
+}
+
+impl Md5 {
+    pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 {
+        let salt: [u8; 4] = [
+            rand::random(),
+            rand::random(),
+            rand::random(),
+            rand::random(),
+        ];
+
+        Md5 {
+            username: username.to_string(),
+            pool_name: pool_name.to_string(),
+            application_name: application_name.to_string(),
+            salt,
+            admin,
+        }
+    }
+
+    /// Issue a 'R' MD5 challenge to the client.
+    pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
+    where
+        W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
+    {
+        let mut res = BytesMut::new();
+        res.put_u8(b'R');
+        res.put_i32(12);
+        res.put_i32(5); // MD5
+        res.put_slice(&self.salt[..]);
+
+        write_all(stream, res).await
+    }
+
+    /// Authenticate client with MD5. This is used for both admin and normal users.
+    pub async fn authenticate<R, W>(&self, read: &mut R, write: &mut W) -> Result<(), Error>
+    where
+        R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
+        W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
+    {
+        let password_hash = response(read).await?;
+
+        if self.admin {
+            let config = get_config();
+
+            // Compare server and client hashes.
+            let our_hash = md5_hash_password(
+                &config.general.admin_username,
+                &config.general.admin_password,
+                &self.salt,
+            );
+
+            if our_hash != password_hash {
+                wrong_password(write, &self.username).await?;
+                Err(Error::ClientError(format!(
+                    "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
+                    self.username, self.pool_name, self.application_name
+                )))
+            } else {
+                Ok(())
+            }
+        } else {
+            match get_pool(&self.pool_name, &self.username, None) {
+                Some(pool) => {
+                    match &pool.settings.user.password {
+                        Some(ref password) => {
+                            let our_hash = md5_hash_password(&self.username, password, &self.salt);
+
+                            if our_hash != password_hash {
+                                wrong_password(write, &self.username).await?;
+
+                                Err(Error::ClientError(format!(
+                                    "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
+                                    self.username, self.pool_name, self.application_name
+                                )))
+                            } else {
+                                validate_pool(write, pool, &self.username, &self.pool_name).await?;
+                                Ok(())
+                            }
+                        }
+
+                        None => {
+                            // Fetch hash from server
+                            let hash = (*pool.auth_hash.read()).clone();
+
+                            let hash = match hash {
+                                Some(hash) => hash.to_string(),
+                                None => {
+                                    refetch_auth_hash(&pool, write, &self.username, &self.pool_name)
+                                        .await?
+                                }
+                            };
+
+                            let our_hash = md5_hash_second_pass(&hash, &self.salt);
+
+                            // Compare hashes
+                            if our_hash != password_hash {
+                                // Server hash maybe changed
+                                let hash = refetch_auth_hash(
+                                    &pool,
+                                    write,
+                                    &self.username,
+                                    &self.pool_name,
+                                )
+                                .await?;
+                                let our_hash = md5_hash_second_pass(&hash, &self.salt);
+
+                                if our_hash != password_hash {
+                                    wrong_password(write, &self.username).await?;
+
+                                    Err(Error::ClientError(format!(
+                                        "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
+                                        self.username, self.pool_name, self.application_name
+                                    )))
+                                } else {
+                                    (*pool.auth_hash.write()) = Some(hash);
+
+                                    validate_pool(
+                                        write,
+                                        pool.clone(),
+                                        &self.username,
+                                        &self.pool_name,
+                                    )
+                                    .await?;
+
+                                    Ok(())
+                                }
+                            } else {
+                                // Cached hash matches: the password is correct.
+                                validate_pool(write, pool.clone(), &self.username, &self.pool_name)
+                                    .await?;
+
+                                Ok(())
+                            }
+                        }
+                    }
+                }
+
+                None => {
+                    error_response(
+                        write,
+                        &format!(
+                            "No pool configured for database: {:?}, user: {:?}",
+                            self.pool_name, self.username
+                        ),
+                    )
+                    .await?;
+
+                    return Err(Error::ClientError(format!(
+                        "Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
+                        self.username, self.pool_name, self.application_name
+                    )));
+                }
+            }
+        }
+    }
+}
diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs
index b9f0e97f..33a9fb17 100644
--- a/src/auth_passthrough.rs
+++ b/src/auth_passthrough.rs
@@ -73,6 +73,7 @@ impl AuthPassthrough {
             password: Some(self.password.clone()),
             pool_size: 1,
             statement_timeout: 0,
+            secrets: None,
         };
 
         let user = &address.username;
diff --git a/src/client.rs b/src/client.rs
index d75c069d..5ef5fe40 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -2,7 +2,7 @@ use crate::errors::Error;
 use crate::pool::BanReason;
 /// Handle clients by pretending to be a PostgreSQL server.
 use bytes::{Buf, BufMut, BytesMut};
-use log::{debug, error, info, trace, warn};
+use log::{debug, error, info, trace};
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Instant;
@@ -90,6 +90,9 @@ pub struct Client<S, T> {
     /// Application name for this client (defaults to pgcat)
     application_name: String,
 
+    /// Which secret the user is using to connect, if any.
+    secret: Option<String>,
+
     /// Used to notify clients about an impending shutdown
     shutdown: Receiver<()>,
 }
@@ -290,7 +293,7 @@ pub async fn client_entrypoint(
 /// Handle the first message the client sends.
 async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
 where
-    S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
+    S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite + std::marker::Send,
 {
     // Get startup message length.
     let len = match stream.read_i32().await {
@@ -377,24 +380,10 @@ pub async fn startup_tls(
     }
 }
 
-async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
-    let address = pool.address(0, 0);
-    if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
-        let hash = apt.fetch_hash(address).await?;
-
-        return Ok(hash);
-    }
-
-    Err(Error::ClientError(format!(
-        "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
-        address.username, address.database
-    )))
-}
-
 impl<S, T> Client<S, T>
 where
-    S: tokio::io::AsyncRead + std::marker::Unpin,
-    T: tokio::io::AsyncWrite + std::marker::Unpin,
+    S: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
+    T: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
 {
     pub fn is_admin(&self) -> bool {
         self.admin
@@ -457,161 +446,39 @@ where
         let process_id: i32 = rand::random();
         let secret_key: i32 = rand::random();
 
-        // Perform MD5 authentication.
-        // TODO: Add SASL support.
-        let salt = md5_challenge(&mut write).await?;
-
-        let code = match read.read_u8().await {
-            Ok(p) => p,
-            Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
-        };
-
-        // PasswordMessage
-        if code as char != 'p' {
-            return Err(Error::ProtocolSyncError(format!(
-                "Expected p, got {}",
-                code as char
-            )));
-        }
-
-        let len = match read.read_i32().await {
-            Ok(len) => len,
-            Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
-        };
-
-        let mut password_response = vec![0u8; (len - 4) as usize];
-
-        match read.read_exact(&mut password_response).await {
-            Ok(_) => (),
-            Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
+        let config = get_config();
+
+        let secret = if admin {
+            debug!("Using md5 auth for admin");
+            let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, true);
+            auth.challenge(&mut write).await?;
+            auth.authenticate(&mut read, &mut write).await?;
+            None
+        } else if !config.tls_enabled() {
+            debug!("Using md5 auth");
+            let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, false);
+            auth.challenge(&mut write).await?;
+            auth.authenticate(&mut read, &mut write).await?;
+            None
+        } else {
+            debug!("Using plain auth");
+            let auth = crate::auth::ClearText::new(&username, &pool_name, &application_name);
+            auth.challenge(&mut write).await?;
+            auth.authenticate(&mut read, &mut write).await?
         };
 
-        // Authenticate admin user.
+        // Authenticated admin user.
         let (transaction_mode, server_info) = if admin {
-            let config = get_config();
-
-            // Compare server and client hashes.
- let password_hash = md5_hash_password( - &config.general.admin_username, - &config.general.admin_password, - &salt, - ); - - if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); - wrong_password(&mut write, username).await?; - - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); - } - (false, generate_server_info_for_admin()) } - // Authenticate normal user. + // Authenticated normal user. else { - let mut pool = match get_pool(pool_name, username) { - Some(pool) => pool, - None => { - error_response( - &mut write, - &format!( - "No pool configured for database: {:?}, user: {:?}", - pool_name, username - ), - ) - .await?; - - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); - } - }; - - // Obtain the hash to compare, we give preference to that written in cleartext in config - // if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained - // when the pool was created. If there is no hash there, we try to fetch it one more time. - let password_hash = if let Some(password) = &pool.settings.user.password { - Some(md5_hash_password(username, password, &salt)) - } else { - if !get_config().is_auth_query_configured() { - return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username))); - } - - let mut hash = (*pool.auth_hash.read()).clone(); - - if hash.is_none() { - warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name); - match refetch_auth_hash(&pool).await { - Ok(fetched_hash) => { - warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name); - { - let mut pool_auth_hash = pool.auth_hash.write(); - *pool_auth_hash = Some(fetched_hash.clone()); - } - - hash = Some(fetched_hash); - } - Err(err) => { - return Err( - Error::ClientError( - format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}", - username, - pool_name, - application_name, - err) - ) - ); - } - } - }; - - Some(md5_hash_second_pass(&hash.unwrap(), &salt)) - }; - - // Once we have the resulting hash, we compare with what the client gave us. - // If they do not match and auth query is set up, we try to refetch the hash one more time - // to see if the password has changed since the pool was created. - // - // @TODO: we could end up fetching again the same password twice (see above). - if password_hash.unwrap() != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name); - let fetched_hash = refetch_auth_hash(&pool).await?; - let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); - - // Ok password changed in server an auth is possible. - if new_password_hash == password_response { - warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. 
Updating.", username, pool_name, application_name);
-                    {
-                        let mut pool_auth_hash = pool.auth_hash.write();
-                        *pool_auth_hash = Some(fetched_hash);
-                    }
-                } else {
-                    wrong_password(&mut write, username).await?;
-                    return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
-                }
-            }
-
+            let pool = get_pool(&pool_name, &username, secret.clone()).unwrap();
             let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
-
-            // If the pool hasn't been validated yet,
-            // connect to the servers and figure out what's what.
-            if !pool.validated() {
-                match pool.validate().await {
-                    Ok(_) => (),
-                    Err(err) => {
-                        error_response(
-                            &mut write,
-                            &format!(
-                                "Pool down for database: {:?}, user: {:?}",
-                                pool_name, username
-                            ),
-                        )
-                        .await?;
-                        return Err(Error::ClientError(format!("Pool down: {:?}", err)));
-                    }
-                }
-            }
-
             (transaction_mode, pool.server_info())
         };
 
-        debug!("Password authentication successful");
+        debug!("Authentication successful");
 
         auth_ok(&mut write).await?;
         write_all(&mut write, server_info).await?;
@@ -619,7 +486,7 @@ where
         ready_for_query(&mut write).await?;
 
         trace!("Startup OK");
-        let pool_stats = match get_pool(pool_name, username) {
+        let pool_stats = match get_pool(pool_name, username, secret.clone()) {
             Some(pool) => {
                 if !admin {
                     pool.stats
@@ -659,6 +526,7 @@ where
             application_name: application_name.to_string(),
             shutdown,
             connected_to_server: false,
+            secret,
         })
     }
 
@@ -693,6 +561,7 @@ where
             application_name: String::from("undefined"),
             shutdown,
             connected_to_server: false,
+            secret: None,
         })
     }
 
@@ -1200,7 +1069,7 @@ where
     /// Retrieve connection pool, if it exists.
     /// Return an error to the client otherwise.
     async fn get_pool(&mut self) -> Result<ConnectionPool, Error> {
-        match get_pool(&self.pool_name, &self.username) {
+        match get_pool(&self.pool_name, &self.username, self.secret.clone()) {
             Some(pool) => Ok(pool),
             None => {
                 error_response(
diff --git a/src/config.rs b/src/config.rs
index 6545457c..e10b854d 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,6 +1,6 @@
 /// Parse the configuration file.
 use arc_swap::ArcSwap;
-use log::{error, info};
+use log::{error, info, warn};
 use once_cell::sync::Lazy;
 use regex::Regex;
 use serde_derive::{Deserialize, Serialize};
@@ -181,6 +181,26 @@ pub struct User {
     pub pool_size: u32,
     #[serde(default)] // 0
     pub statement_timeout: u64,
+    pub secrets: Option<Vec<String>>,
+}
+
+impl User {
+    fn validate(&self) -> Result<(), Error> {
+        match self.secrets {
+            Some(ref secrets) => {
+                for secret in secrets.iter() {
+                    if secret.len() < 16 {
+                        warn!(
+                            "[user: {}] Secret is too short (less than 16 characters)",
+                            self.username
+                        );
+                    }
+                }
+            }
+            None => (),
+        }
+        Ok(())
+    }
 }
 
 impl Default for User {
@@ -190,6 +210,7 @@ impl Default for User {
             password: None,
             pool_size: 15,
             statement_timeout: 0,
+            secrets: None,
         }
     }
 }
@@ -508,6 +529,10 @@ impl Pool {
             None => None,
         };
 
+        for user in self.users.iter() {
+            user.1.validate()?;
+        }
+
         Ok(())
     }
 }
@@ -657,6 +682,11 @@ impl Config {
             }
         }
     }
+
+    /// Checks that we configured TLS.
+    pub fn tls_enabled(&self) -> bool {
+        self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some()
+    }
 }
 
 impl Default for Config {
diff --git a/src/main.rs b/src/main.rs
index 4c8987f1..094fa0d1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -61,6 +61,7 @@ use std::sync::Arc;
 use tokio::sync::broadcast;
 
 mod admin;
+mod auth;
 mod auth_passthrough;
 mod client;
 mod config;
diff --git a/src/messages.rs b/src/messages.rs
index 61c36c6d..c9a2e4ab 100644
--- a/src/messages.rs
+++ b/src/messages.rs
@@ -46,29 +46,6 @@ where
     write_all(stream, auth_ok).await
 }
 
-/// Generate md5 password challenge.
-pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
-where
-    S: tokio::io::AsyncWrite + std::marker::Unpin,
-{
-    // let mut rng = rand::thread_rng();
-    let salt: [u8; 4] = [
-        rand::random(),
-        rand::random(),
-        rand::random(),
-        rand::random(),
-    ];
-
-    let mut res = BytesMut::new();
-    res.put_u8(b'R');
-    res.put_i32(12);
-    res.put_i32(5); // MD5
-    res.put_slice(&salt[..]);
-
-    write_all(stream, res).await?;
-    Ok(salt)
-}
-
 /// Give the client the process_id and secret we generated
 /// used in query cancellation.
 pub async fn backend_key_data(
diff --git a/src/mirrors.rs b/src/mirrors.rs
index 17f91d4d..e8918f5f 100644
--- a/src/mirrors.rs
+++ b/src/mirrors.rs
@@ -34,7 +34,7 @@ impl MirroredClient {
             None => (default, default, crate::config::Pool::default()),
         };
 
-        let identifier = PoolIdentifier::new(&self.database, &self.user.username);
+        let identifier = PoolIdentifier::new(&self.database, &self.user.username, None);
 
         let manager = ServerPool::new(
             self.address.clone(),
diff --git a/src/pool.rs b/src/pool.rs
index e1ab7cb4..8d44efdc 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -59,24 +59,22 @@ pub struct PoolIdentifier {
 
     /// The username the client connects with. Each user gets its own pool.
     pub user: String,
+
+    /// The client secret (password).
+    pub secret: Option<String>,
 }
 
 impl PoolIdentifier {
     /// Create a new user/pool identifier.
-    pub fn new(db: &str, user: &str) -> PoolIdentifier {
+    pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
         PoolIdentifier {
             db: db.to_string(),
            user: user.to_string(),
+            secret,
         }
     }
 }
 
-impl From<&Address> for PoolIdentifier {
-    fn from(address: &Address) -> PoolIdentifier {
-        PoolIdentifier::new(&address.database, &address.username)
-    }
-}
-
 /// Pool settings.
 #[derive(Clone, Debug)]
 pub struct PoolSettings {
@@ -210,224 +208,241 @@ impl ConnectionPool {
 
         // There is one pool per database/user pair.
         for user in pool_config.users.values() {
-            let old_pool_ref = get_pool(pool_name, &user.username);
-            let identifier = PoolIdentifier::new(pool_name, &user.username);
-
-            match old_pool_ref {
-                Some(pool) => {
-                    // If the pool hasn't changed, get existing reference and insert it into the new_pools.
-                    // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
-                    if pool.config_hash == new_pool_hash_value {
-                        info!(
-                            "[pool: {}][user: {}] has not changed",
-                            pool_name, user.username
-                        );
-                        new_pools.insert(identifier.clone(), pool.clone());
-                        continue;
+            let mut secrets = match &user.secrets {
+                Some(_) => user
+                    .secrets
+                    .as_ref()
+                    .unwrap()
+                    .iter()
+                    .map(|secret| Some(secret.to_string()))
+                    .collect::<Vec<Option<String>>>(),
+                None => vec![],
+            };
+
+            secrets.push(None);
+
+            for secret in secrets {
+                let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
+                let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
+
+                match old_pool_ref {
+                    Some(pool) => {
+                        // If the pool hasn't changed, get existing reference and insert it into the new_pools.
+                        // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
+                        if pool.config_hash == new_pool_hash_value {
+                            info!(
+                                "[pool: {}][user: {}] has not changed",
+                                pool_name, user.username
+                            );
+                            new_pools.insert(identifier.clone(), pool.clone());
+                            continue;
+                        }
                     }
+                    None => (),
                 }
-                None => (),
-            }
 
-            info!(
-                "[pool: {}][user: {}] creating new pool",
-                pool_name, user.username
-            );
+                info!(
+                    "[pool: {}][user: {}] creating new pool",
+                    pool_name, user.username
+                );
 
-            let mut shards = Vec::new();
-            let mut addresses = Vec::new();
-            let mut banlist = Vec::new();
-            let mut shard_ids = pool_config
-                .shards
-                .clone()
-                .into_keys()
-                .collect::<Vec<String>>();
-            let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
-
-            // Allow the pool to be seen in statistics
-            pool_stats.register(pool_stats.clone());
-
-            // Sort by shard number to ensure consistency.
-            shard_ids.sort_by_key(|k| k.parse::<usize>().unwrap());
-            let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
-
-            for shard_idx in &shard_ids {
-                let shard = &pool_config.shards[shard_idx];
-                let mut pools = Vec::new();
-                let mut servers = Vec::new();
-                let mut replica_number = 0;
-
-                // Load Mirror settings
-                for (address_index, server) in shard.servers.iter().enumerate() {
-                    let mut mirror_addresses = vec![];
-                    if let Some(mirror_settings_vec) = &shard.mirrors {
-                        for (mirror_idx, mirror_settings) in
-                            mirror_settings_vec.iter().enumerate()
-                        {
-                            if mirror_settings.mirroring_target_index != address_index {
-                                continue;
+                let mut shards = Vec::new();
+                let mut addresses = Vec::new();
+                let mut banlist = Vec::new();
+                let mut shard_ids = pool_config
+                    .shards
+                    .clone()
+                    .into_keys()
+                    .collect::<Vec<String>>();
+                let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
+
+                // Allow the pool to be seen in statistics
+                pool_stats.register(pool_stats.clone());
+
+                // Sort by shard number to ensure consistency.
+ shard_ids.sort_by_key(|k| k.parse::().unwrap()); + let pool_auth_hash: Arc>> = Arc::new(RwLock::new(None)); + + for shard_idx in &shard_ids { + let shard = &pool_config.shards[shard_idx]; + let mut pools = Vec::new(); + let mut servers = Vec::new(); + let mut replica_number = 0; + + // Load Mirror settings + for (address_index, server) in shard.servers.iter().enumerate() { + let mut mirror_addresses = vec![]; + if let Some(mirror_settings_vec) = &shard.mirrors { + for (mirror_idx, mirror_settings) in + mirror_settings_vec.iter().enumerate() + { + if mirror_settings.mirroring_target_index != address_index { + continue; + } + mirror_addresses.push(Address { + id: address_id, + database: shard.database.clone(), + host: mirror_settings.host.clone(), + port: mirror_settings.port, + role: server.role, + address_index: mirror_idx, + replica_number, + shard: shard_idx.parse::().unwrap(), + username: user.username.clone(), + pool_name: pool_name.clone(), + mirrors: vec![], + stats: Arc::new(AddressStats::default()), + }); + address_id += 1; } - mirror_addresses.push(Address { - id: address_id, - database: shard.database.clone(), - host: mirror_settings.host.clone(), - port: mirror_settings.port, - role: server.role, - address_index: mirror_idx, - replica_number, - shard: shard_idx.parse::().unwrap(), - username: user.username.clone(), - pool_name: pool_name.clone(), - mirrors: vec![], - stats: Arc::new(AddressStats::default()), - }); - address_id += 1; } - } - let address = Address { - id: address_id, - database: shard.database.clone(), - host: server.host.clone(), - port: server.port, - role: server.role, - address_index, - replica_number, - shard: shard_idx.parse::().unwrap(), - username: user.username.clone(), - pool_name: pool_name.clone(), - mirrors: mirror_addresses, - stats: Arc::new(AddressStats::default()), - }; - - address_id += 1; - - if server.role == Role::Replica { - replica_number += 1; - } + let address = Address { + id: address_id, + database: shard.database.clone(), + host: server.host.clone(), + port: server.port, + role: server.role, + address_index, + replica_number, + shard: shard_idx.parse::().unwrap(), + username: user.username.clone(), + pool_name: pool_name.clone(), + mirrors: mirror_addresses, + stats: Arc::new(AddressStats::default()), + }; + + address_id += 1; + + if server.role == Role::Replica { + replica_number += 1; + } - // We assume every server in the pool share user/passwords - let auth_passthrough = AuthPassthrough::from_pool_config(pool_config); - - if let Some(apt) = &auth_passthrough { - match apt.fetch_hash(&address).await { - Ok(ok) => { - if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) { - if ok != *pool_auth_hash_value { - warn!("Hash is not the same across shards of the same pool, client auth will \ - be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database); - } - } - debug!("Hash obtained for {:?}", address); - { - let mut pool_auth_hash = pool_auth_hash.write(); - *pool_auth_hash = Some(ok.clone()); - } - }, - Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. 
Error: {:?}", err), - } - } + // We assume every server in the pool share user/passwords + let auth_passthrough = AuthPassthrough::from_pool_config(pool_config); + + if let Some(apt) = &auth_passthrough { + match apt.fetch_hash(&address).await { + Ok(ok) => { + if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) { + if ok != *pool_auth_hash_value { + warn!("Hash is not the same across shards of the same pool, client auth will \ + be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database); + } + } + debug!("Hash obtained for {:?}", address); + { + let mut pool_auth_hash = pool_auth_hash.write(); + *pool_auth_hash = Some(ok.clone()); + } + }, + Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err), + } + } - let manager = ServerPool::new( - address.clone(), - user.clone(), - &shard.database, - client_server_map.clone(), - pool_stats.clone(), - pool_auth_hash.clone(), - ); + let manager = ServerPool::new( + address.clone(), + user.clone(), + &shard.database, + client_server_map.clone(), + pool_stats.clone(), + pool_auth_hash.clone(), + ); - let connect_timeout = match pool_config.connect_timeout { - Some(connect_timeout) => connect_timeout, - None => config.general.connect_timeout, - }; - - let idle_timeout = match pool_config.idle_timeout { - Some(idle_timeout) => idle_timeout, - None => config.general.idle_timeout, - }; - - let pool = Pool::builder() - .max_size(user.pool_size) - .connection_timeout(std::time::Duration::from_millis(connect_timeout)) - .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) - .test_on_check_out(false) - .build(manager) - .await - .unwrap(); - - pools.push(pool); - servers.push(address); - } + let connect_timeout = match pool_config.connect_timeout { + Some(connect_timeout) => connect_timeout, + None => config.general.connect_timeout, + }; + + let idle_timeout = match pool_config.idle_timeout { + Some(idle_timeout) => idle_timeout, + None => config.general.idle_timeout, + }; + + let pool = Pool::builder() + .max_size(user.pool_size) + .connection_timeout(std::time::Duration::from_millis( + connect_timeout, + )) + .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) + .test_on_check_out(false) + .build(manager) + .await + .unwrap(); + + pools.push(pool); + servers.push(address); + } - shards.push(pools); - addresses.push(servers); - banlist.push(HashMap::new()); - } + shards.push(pools); + addresses.push(servers); + banlist.push(HashMap::new()); + } - assert_eq!(shards.len(), addresses.len()); - if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) { - info!( - "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}", - pool_name, user.username - ); - } + assert_eq!(shards.len(), addresses.len()); + if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) { + info!( + "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}", + pool_name, user.username + ); + } - let pool = ConnectionPool { - databases: shards, - stats: pool_stats, - addresses, - banlist: Arc::new(RwLock::new(banlist)), - config_hash: new_pool_hash_value, - server_info: Arc::new(RwLock::new(BytesMut::new())), - auth_hash: pool_auth_hash, - settings: PoolSettings { - pool_mode: pool_config.pool_mode, - load_balancing_mode: pool_config.load_balancing_mode, - // shards: pool_config.shards.clone(), - shards: shard_ids.len(), - user: user.clone(), - default_role: match pool_config.default_role.as_str() { - "any" => None, 
- "replica" => Some(Role::Replica), - "primary" => Some(Role::Primary), - _ => unreachable!(), + let pool = ConnectionPool { + databases: shards, + stats: pool_stats, + addresses, + banlist: Arc::new(RwLock::new(banlist)), + config_hash: new_pool_hash_value, + server_info: Arc::new(RwLock::new(BytesMut::new())), + auth_hash: pool_auth_hash, + settings: PoolSettings { + pool_mode: pool_config.pool_mode, + load_balancing_mode: pool_config.load_balancing_mode, + // shards: pool_config.shards.clone(), + shards: shard_ids.len(), + user: user.clone(), + default_role: match pool_config.default_role.as_str() { + "any" => None, + "replica" => Some(Role::Replica), + "primary" => Some(Role::Primary), + _ => unreachable!(), + }, + query_parser_enabled: pool_config.query_parser_enabled, + primary_reads_enabled: pool_config.primary_reads_enabled, + sharding_function: pool_config.sharding_function, + automatic_sharding_key: pool_config.automatic_sharding_key.clone(), + healthcheck_delay: config.general.healthcheck_delay, + healthcheck_timeout: config.general.healthcheck_timeout, + ban_time: config.general.ban_time, + sharding_key_regex: pool_config + .sharding_key_regex + .clone() + .map(|regex| Regex::new(regex.as_str()).unwrap()), + shard_id_regex: pool_config + .shard_id_regex + .clone() + .map(|regex| Regex::new(regex.as_str()).unwrap()), + regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), + auth_query: pool_config.auth_query.clone(), + auth_query_user: pool_config.auth_query_user.clone(), + auth_query_password: pool_config.auth_query_password.clone(), }, - query_parser_enabled: pool_config.query_parser_enabled, - primary_reads_enabled: pool_config.primary_reads_enabled, - sharding_function: pool_config.sharding_function, - automatic_sharding_key: pool_config.automatic_sharding_key.clone(), - healthcheck_delay: config.general.healthcheck_delay, - healthcheck_timeout: config.general.healthcheck_timeout, - ban_time: config.general.ban_time, - sharding_key_regex: pool_config - .sharding_key_regex - .clone() - .map(|regex| Regex::new(regex.as_str()).unwrap()), - shard_id_regex: pool_config - .shard_id_regex - .clone() - .map(|regex| Regex::new(regex.as_str()).unwrap()), - regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), - auth_query: pool_config.auth_query.clone(), - auth_query_user: pool_config.auth_query_user.clone(), - auth_query_password: pool_config.auth_query_password.clone(), - }, - validated: Arc::new(AtomicBool::new(false)), - paused: Arc::new(AtomicBool::new(false)), - paused_waiter: Arc::new(Notify::new()), - }; + validated: Arc::new(AtomicBool::new(false)), + paused: Arc::new(AtomicBool::new(false)), + paused_waiter: Arc::new(Notify::new()), + }; - // Connect to the servers to make sure pool configuration is valid - // before setting it globally. - // Do this async and somewhere else, we don't have to wait here. - let mut validate_pool = pool.clone(); - tokio::task::spawn(async move { - let _ = validate_pool.validate().await; - }); + // Connect to the servers to make sure pool configuration is valid + // before setting it globally. + // Do this async and somewhere else, we don't have to wait here. + let mut validate_pool = pool.clone(); + tokio::task::spawn(async move { + let _ = validate_pool.validate().await; + }); - // There is one pool per database/user pair. - new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool); + // There is one pool per database/user pair. 
+                new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
+            }
         }
     }
 
@@ -924,10 +939,10 @@ impl ManageConnection for ServerPool {
 }
 
 /// Get the connection pool
-pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
-    (*(*POOLS.load()))
-        .get(&PoolIdentifier::new(db, user))
-        .cloned()
+pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
+    let identifier = PoolIdentifier::new(db, user, secret);
+
+    (*(*POOLS.load())).get(&identifier).cloned()
 }
 
 /// Get a pointer to all configured pools.
diff --git a/src/prometheus.rs b/src/prometheus.rs
index 6e578bf0..67517900 100644
--- a/src/prometheus.rs
+++ b/src/prometheus.rs
@@ -9,7 +9,7 @@ use std::sync::atomic::Ordering;
 use std::sync::Arc;
 
 use crate::config::Address;
-use crate::pool::get_all_pools;
+use crate::pool::{get_all_pools, PoolIdentifier};
 use crate::stats::{get_pool_stats, get_server_stats, ServerStats};
 
 struct MetricHelpType {
@@ -233,10 +233,10 @@ impl PrometheusMetric {
         Self::from_name(&format!("stats_{}", name), value, labels)
     }
 
-    fn from_pool(pool: &(String, String), name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
+    fn from_pool(pool: &PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
         let mut labels = HashMap::new();
-        labels.insert("pool", pool.0.clone());
-        labels.insert("user", pool.1.clone());
+        labels.insert("pool", pool.db.clone());
+        labels.insert("user", pool.user.clone());
 
         Self::from_name(&format!("pools_{}", name), value, labels)
     }
@@ -294,7 +294,7 @@ fn push_pool_stats(lines: &mut Vec<String>) {
             } else {
                 warn!(
                     "Metric {} not implemented for ({},{})",
-                    name, pool.0, pool.1
+                    name, pool.db, pool.user
                 );
             }
         }
diff --git a/src/stats.rs b/src/stats.rs
index 5b7895b4..63488def 100644
--- a/src/stats.rs
+++ b/src/stats.rs
@@ -22,7 +22,7 @@ pub use server::{ServerState, ServerStats};
 /// Convenience types for various stats
 type ClientStatesLookup = HashMap<i32, Arc<ClientStats>>;
 type ServerStatesLookup = HashMap<i32, Arc<ServerStats>>;
-type PoolStatsLookup = HashMap<(String, String), Arc<PoolStats>>;
+type PoolStatsLookup = HashMap<PoolIdentifier, Arc<PoolStats>>;
 
 /// Stats for individual client connections
 /// Used in SHOW CLIENTS.
@@ -83,9 +83,7 @@ impl Reporter {
     /// Register a pool with the stats system.
fn pool_register(&self, identifier: PoolIdentifier, stats: Arc) { - POOL_STATS - .write() - .insert((identifier.db, identifier.user), stats); + POOL_STATS.write().insert(identifier, stats); } } diff --git a/src/stats/pool.rs b/src/stats/pool.rs index 1b01ef2e..51ec78fd 100644 --- a/src/stats/pool.rs +++ b/src/stats/pool.rs @@ -102,6 +102,13 @@ impl PoolStats { self.identifier.user.clone() } + pub fn redacted_secret(&self) -> String { + match self.identifier.secret { + Some(ref s) => format!("****{}", &s[s.len() - 4..]), + None => "".to_string(), + } + } + pub fn pool_mode(&self) -> PoolMode { self.config.pool_mode } diff --git a/tests/ruby/auth_spec.rb b/tests/ruby/auth_spec.rb new file mode 100644 index 00000000..acf458ba --- /dev/null +++ b/tests/ruby/auth_spec.rb @@ -0,0 +1,39 @@ +# frozen_string_literal: true +require_relative 'spec_helper' + + +describe "Authentication" do + describe "multiple secrets configured" do + let(:secrets) { ["one_secret", "two_secret"] } + let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info", secrets=["one_secret", "two_secret"]) } + + after do + processes.all_databases.map(&:reset) + processes.pgcat.shutdown + end + + it "can connect using all secrets and postgres password" do + secrets.push("sharding_user").each do |secret| + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password=secret)) + conn.exec("SELECT current_user") + end + end + end + + describe "no secrets configured" do + let(:secrets) { [] } + let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info") } + + after do + processes.all_databases.map(&:reset) + processes.pgcat.shutdown + end + + it "can connect using only the password" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.exec("SELECT current_user") + + expect { PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password="secret_one")) }.to raise_error PG::ConnectionBad + end + end +end diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index 13dc6686..8395d09a 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -12,14 +12,18 @@ def deep_merge(second) module Helpers module Pgcat - def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info") + def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", secrets=nil) user = { "password" => "sharding_user", "pool_size" => pool_size, "statement_timeout" => 0, - "username" => "sharding_user" + "username" => "sharding_user", } + if !secrets.nil? 
+ user["secrets"] = secrets + end + pgcat = PgcatProcess.new(log_level) primary0 = PgInstance.new(5432, user["username"], user["password"], "shard0") primary1 = PgInstance.new(7432, user["username"], user["password"], "shard1") @@ -27,7 +31,7 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod pgcat_cfg = pgcat.current_config pgcat_cfg["pools"] = { - "#{pool_name}" => { + "#{pool_name}" => { "default_role" => "any", "pool_mode" => pool_mode, "load_balancing_mode" => lb_mode, @@ -41,8 +45,14 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] }, }, "users" => { "0" => user } - } + }, } + + if !secrets.nil? + pgcat_cfg["general"]["tls_certificate"] = "../../.circleci/server.cert" + pgcat_cfg["general"]["tls_private_key"] = "../../.circleci/server.key" + end + pgcat.update_config(pgcat_cfg) pgcat.start diff --git a/tests/ruby/helpers/pgcat_process.rb b/tests/ruby/helpers/pgcat_process.rb index e1dbea8b..5489bc90 100644 --- a/tests/ruby/helpers/pgcat_process.rb +++ b/tests/ruby/helpers/pgcat_process.rb @@ -112,10 +112,13 @@ def admin_connection_string "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat" end - def connection_string(pool_name, username, password = nil) + def connection_string(pool_name, username, password=nil) cfg = current_config user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } - "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" + + password = if password.nil? then user_obj["password"] else password end + + "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}" end def example_connection_string
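For reference, a sketch of the hashing that the `md5_hash_password` / `md5_hash_second_pass` helpers used in `src/auth.rs` are assumed to implement — this is the standard Postgres MD5 scheme written with the `md5` crate, not PgCat's actual code, and the function names here are illustrative:

```rust
use md5::compute;

fn md5_hex(data: &[u8]) -> String {
    format!("{:x}", compute(data))
}

/// What the server stores and what an `auth_query` returns:
/// "md5" + md5(password + username).
fn stored_md5(username: &str, password: &str) -> String {
    format!("md5{}", md5_hex(format!("{}{}", password, username).as_bytes()))
}

/// What the client sends back for a salted challenge (the "second pass"):
/// "md5" + md5(inner_hex + salt), where inner_hex is the stored hash without
/// its "md5" prefix.
fn challenge_response(stored: &str, salt: &[u8; 4]) -> String {
    let inner = stored.trim_start_matches("md5");
    let mut salted = inner.as_bytes().to_vec();
    salted.extend_from_slice(salt);
    format!("md5{}", md5_hex(&salted))
}
```

Under that reading, `md5_hash_password(user, password, salt)` corresponds to `challenge_response(&stored_md5(user, password), salt)`, which is why the MD5 path can authenticate clients either from the cleartext `password` in the config or from a hash fetched via auth passthrough, while the cleartext path added here never hashes at all — the secret only selects which pool the client lands in.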