Skip to content

Implement Decode, Encode and Type for Box, Arc, Cow and Rc #3674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Relax Sized bound for Decode, Encode
  • Loading branch information
joeydewaal committed Apr 1, 2025
commit 6669bd3adf0529dee19b598758cd10a347caaef8
61 changes: 53 additions & 8 deletions sqlx-core/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ where
}

macro_rules! impl_decode_for_smartpointer {
($smart_pointer:ty) => {
impl<'r, DB, T> Decode<'r, DB> for $smart_pointer
($smart_pointer:tt) => {
impl<'r, DB, T> Decode<'r, DB> for $smart_pointer<T>
where
DB: Database,
T: Decode<'r, DB>,
Expand All @@ -93,21 +93,66 @@ macro_rules! impl_decode_for_smartpointer {
Ok(Self::new(T::decode(value)?))
}
}

impl<'r, DB> Decode<'r, DB> for $smart_pointer<str>
where
DB: Database,
&'r str: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let ref_str = <&str as Decode<DB>>::decode(value)?;
Ok(ref_str.into())
}
}

impl<'r, DB> Decode<'r, DB> for $smart_pointer<[u8]>
where
DB: Database,
&'r [u8]: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let ref_str = <&[u8] as Decode<DB>>::decode(value)?;
Ok(ref_str.into())
}
}
};
}

impl_decode_for_smartpointer!(Arc<T>);
impl_decode_for_smartpointer!(Box<T>);
impl_decode_for_smartpointer!(Rc<T>);
impl_decode_for_smartpointer!(Arc);
impl_decode_for_smartpointer!(Box);
impl_decode_for_smartpointer!(Rc);

// implement `Decode` for Cow<T> for all SQL types
impl<'r, DB, T> Decode<'r, DB> for Cow<'_, T>
where
DB: Database,
T: Decode<'r, DB>,
T: ToOwned<Owned = T>,
T: ToOwned,
<T as ToOwned>::Owned: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let owned = <<T as ToOwned>::Owned as Decode<DB>>::decode(value)?;
Ok(Cow::Owned(owned))
}
}

impl<'r, DB> Decode<'r, DB> for Cow<'r, str>
where
DB: Database,
&'r str: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let borrowed = <&str as Decode<DB>>::decode(value)?;
Ok(Cow::Borrowed(borrowed))
}
}

impl<'r, DB> Decode<'r, DB> for Cow<'r, [u8]>
where
DB: Database,
&'r [u8]: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(Cow::Owned(T::decode(value)?))
let borrowed = <&[u8] as Decode<DB>>::decode(value)?;
Ok(Cow::Borrowed(borrowed))
}
}
10 changes: 5 additions & 5 deletions sqlx-core/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,28 +175,28 @@ impl_encode_for_smartpointer!(Rc<T>);
impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'_, T>
where
T: Encode<'q, DB>,
T: ToOwned<Owned = T>,
T: ToOwned,
{
#[inline]
fn encode(self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> Result<IsNull, BoxDynError> {
<T as Encode<DB>>::encode_by_ref(self.as_ref(), buf)
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
}

#[inline]
fn encode_by_ref(
&self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<&T as Encode<DB>>::encode(self, buf)
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
}

#[inline]
fn produces(&self) -> Option<DB::TypeInfo> {
(**self).produces()
<&T as Encode<DB>>::produces(&self.as_ref())
}

#[inline]
fn size_hint(&self) -> usize {
(**self).size_hint()
<&T as Encode<DB>>::size_hint(&self.as_ref())
}
}
2 changes: 1 addition & 1 deletion sqlx-core/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ impl_type_for_smartpointer!(Rc<T>);
impl<T, DB: Database> Type<DB> for Cow<'_, T>
where
T: Type<DB>,
T: ToOwned<Owned = T>,
T: ToOwned,
T: ?Sized,
{
fn type_info() -> DB::TypeInfo {
Expand Down
6 changes: 0 additions & 6 deletions sqlx-mysql/src/types/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ impl Encode<'_, MySql> for Box<[u8]> {
}
}

impl<'r> Decode<'r, MySql> for Box<[u8]> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
<&[u8] as Decode<MySql>>::decode(value).map(Box::from)
}
}

impl Type<MySql> for Vec<u8> {
fn type_info() -> MySqlTypeInfo {
<[u8] as Type<MySql>>::type_info()
Expand Down
15 changes: 5 additions & 10 deletions sqlx-mysql/src/types/str.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::borrow::Cow;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::io::MySqlBufMutExt;
use crate::protocol::text::{ColumnFlags, ColumnType};
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
use std::borrow::Cow;

impl Type<MySql> for str {
fn type_info() -> MySqlTypeInfo {
Expand Down Expand Up @@ -52,12 +53,6 @@ impl Encode<'_, MySql> for Box<str> {
}
}

impl<'r> Decode<'r, MySql> for Box<str> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
<&str as Decode<MySql>>::decode(value).map(Box::from)
}
}

impl Type<MySql> for String {
fn type_info() -> MySqlTypeInfo {
<str as Type<MySql>>::type_info()
Expand Down Expand Up @@ -89,8 +84,8 @@ impl Encode<'_, MySql> for Cow<'_, str> {
}
}

impl<'r> Decode<'r, MySql> for Cow<'r, str> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
value.as_str().map(Cow::Borrowed)
impl Encode<'_, MySql> for Cow<'_, [u8]> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<MySql>>::encode(self.as_ref(), buf)
}
}
9 changes: 0 additions & 9 deletions sqlx-postgres/src/types/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,6 @@ fn text_hex_decode_input(value: PgValueRef<'_>) -> Result<&[u8], BoxDynError> {
.map_err(Into::into)
}

impl Decode<'_, Postgres> for Box<[u8]> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => Box::from(value.as_bytes()?),
PgValueFormat::Text => Box::from(hex::decode(text_hex_decode_input(value)?)?),
})
}
}

impl Decode<'_, Postgres> for Vec<u8> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
Expand Down
30 changes: 12 additions & 18 deletions sqlx-postgres/src/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,6 @@ impl Encode<'_, Postgres> for &'_ str {
}
}

impl Encode<'_, Postgres> for Cow<'_, str> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
match self {
Cow::Borrowed(str) => <&str as Encode<Postgres>>::encode(*str, buf),
Cow::Owned(str) => <&str as Encode<Postgres>>::encode(&**str, buf),
}
}
}

impl Encode<'_, Postgres> for Box<str> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&str as Encode<Postgres>>::encode(&**self, buf)
Expand All @@ -109,20 +100,23 @@ impl<'r> Decode<'r, Postgres> for &'r str {
}
}

impl<'r> Decode<'r, Postgres> for Cow<'r, str> {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(Cow::Borrowed(value.as_str()?))
impl Decode<'_, Postgres> for String {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?.to_owned())
}
}

impl<'r> Decode<'r, Postgres> for Box<str> {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(Box::from(value.as_str()?))
impl Encode<'_, Postgres> for Cow<'_, str> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
match self {
Cow::Borrowed(str) => <&str as Encode<Postgres>>::encode(*str, buf),
Cow::Owned(str) => <&str as Encode<Postgres>>::encode(&**str, buf),
}
}
}

impl Decode<'_, Postgres> for String {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?.to_owned())
impl Encode<'_, Postgres> for Cow<'_, [u8]> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<Postgres>>::encode(self.as_ref(), buf)
}
}
6 changes: 0 additions & 6 deletions sqlx-sqlite/src/types/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,6 @@ impl Encode<'_, Sqlite> for Box<[u8]> {
}
}

impl Decode<'_, Sqlite> for Box<[u8]> {
fn decode(value: SqliteValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(Box::from(value.blob()))
}
}

impl Type<Sqlite> for Vec<u8> {
fn type_info() -> SqliteTypeInfo {
<&[u8] as Type<Sqlite>>::type_info()
Expand Down
23 changes: 14 additions & 9 deletions sqlx-sqlite/src/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ impl Encode<'_, Sqlite> for Box<str> {
}
}

impl Decode<'_, Sqlite> for Box<str> {
fn decode(value: SqliteValueRef<'_>) -> Result<Self, BoxDynError> {
value.text().map(Box::from)
}
}

impl Type<Sqlite> for String {
fn type_info() -> SqliteTypeInfo {
<&str as Type<Sqlite>>::type_info()
Expand Down Expand Up @@ -101,8 +95,19 @@ impl<'q> Encode<'q, Sqlite> for Cow<'q, str> {
}
}

impl<'r> Decode<'r, Sqlite> for Cow<'r, str> {
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
value.text().map(Cow::Borrowed)
impl<'q> Encode<'q, Sqlite> for Cow<'q, [u8]> {
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Blob(self));

Ok(IsNull::No)
}

fn encode_by_ref(
&self,
args: &mut Vec<SqliteArgumentValue<'q>>,
) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Blob(self.clone()));

Ok(IsNull::No)
}
}
51 changes: 50 additions & 1 deletion tests/mysql/types.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
extern crate time_ as time;

use std::borrow::Cow;
use std::net::SocketAddr;
use std::rc::Rc;
#[cfg(feature = "rust_decimal")]
use std::str::FromStr;
use std::sync::Arc;

use sqlx::mysql::MySql;
use sqlx::{Executor, Row};
use sqlx::{Executor, FromRow, Row};

use sqlx::types::Text;

Expand Down Expand Up @@ -384,3 +387,49 @@ CREATE TEMPORARY TABLE user_login (

Ok(())
}

#[sqlx_macros::test]
async fn test_smartpointers() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let user_age: (Arc<i32>, Cow<'static, i32>, Box<i32>, i32) =
sqlx::query_as("SELECT ?, ?, ?, ?")
.bind(Arc::new(1i32))
.bind(Cow::<'_, i32>::Borrowed(&2i32))
.bind(Box::new(3i32))
.bind(Rc::new(4i32))
.fetch_one(&mut conn)
.await?;

assert!(user_age.0.as_ref() == &1);
assert!(user_age.1.as_ref() == &2);
assert!(user_age.2.as_ref() == &3);
assert!(user_age.3 == 4);
Ok(())
}

#[sqlx_macros::test]
async fn test_str_slice() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let box_str: Box<str> = "John".into();
let box_slice: Box<[u8]> = [1, 2, 3, 4].into();
let cow_str: Cow<'static, str> = "Phil".into();
let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]);

let row = sqlx::query("SELECT ?, ?, ?, ?")
.bind(&box_str)
.bind(&box_slice)
.bind(&cow_str)
.bind(&cow_slice)
.fetch_one(&mut conn)
.await?;

let data: (Box<str>, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?;

assert!(data.0 == box_str);
assert!(data.1 == box_slice);
assert!(data.2 == cow_str);
assert!(data.3 == cow_slice);
Ok(())
}
Loading