Skip to content

Commit 9fc72c0

Browse files
committed
Add VecRef enum to set rust values to Oracle VECTOR data type
1 parent 6bc23c0 commit 9fc72c0

File tree

5 files changed

+300
-1
lines changed

5 files changed

+300
-1
lines changed

ChangeLog.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ New features:
88
* Add [`OracleType::Vector`] variant
99
* Add [`InnerValue::Vector`] variant
1010
* Add [`VecFmt`] enum type
11+
* Add [`VecRef`] enum type to set rust values to Oracle VECTOR data type
12+
* Add [`VectorFormat`] trait type
1113

1214
Incompatible changes:
1315

@@ -593,8 +595,10 @@ Incompatible changes:
593595
[`StatementBuilder::tag()`]: https://www.jiubao.org/rust-oracle/oracle/struct.StatementBuilder.html#method.tag
594596
[`StmtParam`]: https://docs.rs/oracle/0.5.*/oracle/enum.StmtParam.html
595597
[`StmtParam::FetchArraySize`]: https://docs.rs/oracle/0.5.*/oracle/enum.StmtParam.html#variant.FetchArraySize
596-
[`VecFmt`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/vector/enum.VecFmt.html
597598
[`Timestamp::and_prec()`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/struct.Timestamp.html#method.and_prec
598599
[`Timestamp::and_tz_hm_offset()`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/struct.Timestamp.html#method.and_tz_hm_offset
599600
[`Timestamp::and_tz_offset()`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/struct.Timestamp.html#method.and_tz_offset
600601
[`Timestamp::new()`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/struct.Timestamp.html#method.new
602+
[`VecFmt`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/vector/enum.VecFmt.html
603+
[`VecRef`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/vector/enum.VecRef.html
604+
[`VectorFormat`]: https://www.jiubao.org/rust-oracle/oracle/sql_type/vector/trait.VectorFormat.html

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,14 @@ mod private {
272272

273273
pub trait Sealed {}
274274

275+
impl Sealed for i8 {}
275276
impl Sealed for u8 {}
276277
impl Sealed for u16 {}
277278
impl Sealed for u32 {}
278279
impl Sealed for u64 {}
279280
impl Sealed for usize {}
281+
impl Sealed for f32 {}
282+
impl Sealed for f64 {}
280283
impl Sealed for bool {}
281284
impl Sealed for str {}
282285
impl Sealed for [u8] {}

src/sql_type/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
//! SQL data types
1717
18+
#[cfg(doc)]
19+
use crate::sql_type::vector::VecRef;
1820
use crate::Connection;
1921
use crate::ErrorKind;
2022
use crate::Result;
@@ -122,6 +124,7 @@ pub trait FromSql {
122124
/// | [`IntervalDS`] | `interval day(9) to second(9)` |
123125
/// | [`IntervalYM`] | `interval year(9) to month` |
124126
/// | [`RefCursor`] | `ref cursor` |
127+
/// | [`VecRef`] | `vector` |
125128
///
126129
/// When `chrono` feature is enabled, the followings are added.
127130
///
@@ -157,6 +160,7 @@ pub trait ToSqlNull {
157160
/// | [`IntervalYM`] | `interval year(9) to month` | The specified value |
158161
/// | [`Collection`] | type returned by [`Collection::object_type`] | The specified value |
159162
/// | [`Object`] | type returned by [`Object::object_type`] | The specified value |
163+
/// | [`VecRef`] | `vector` |
160164
/// | `Option\<T>` where T: `ToSql` + [`ToSqlNull`] | When the value is `Some`, the contained value decides the Oracle type. When it is `None`, ToSqlNull decides it. | When the value is `Some`, the contained value. When it is `None`, a null value.
161165
/// | [`OracleType`] | type represented by the OracleType. | a null value |
162166
/// | `(&ToSql, &OracleType)` | type represented by the second element. | The value of the first element |

src/sql_type/vector.rs

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,18 @@
1313
// (ii) the Apache License v 2.0. (http://www.apache.org/licenses/LICENSE-2.0)
1414
//-----------------------------------------------------------------------------
1515

16+
use crate::private;
17+
use crate::sql_type::OracleType;
18+
use crate::sql_type::SqlValue;
19+
use crate::sql_type::ToSql;
20+
use crate::sql_type::ToSqlNull;
21+
use crate::Connection;
1622
use crate::Error;
23+
use crate::ErrorKind;
1724
use crate::Result;
1825
use odpic_sys::*;
1926
use std::fmt;
27+
use std::os::raw::c_void;
2028

2129
/// Vector dimension element format
2230
///
@@ -75,9 +83,230 @@ impl fmt::Display for VecFmt {
7583
}
7684
}
7785

86+
/// Reference to vector dimension elements
87+
///
88+
/// See the [module-level documentation](index.html) for more.
89+
#[derive(Clone, Debug, PartialEq)]
90+
#[non_exhaustive]
91+
pub enum VecRef<'a> {
92+
/// Wraps `[f32]` slice data as Oracle data type `VECTOR(FLOAT32)`
93+
Float32(&'a [f32]),
94+
/// Wraps `[f64]` slice data as Oracle data type `VECTOR(FLOAT64)`
95+
Float64(&'a [f64]),
96+
/// Wraps `[i8]` slice data as Oracle data type `VECTOR(INT8)`
97+
Int8(&'a [i8]),
98+
/// Wraps `[u8]` slice data as Oracle data type `VECTOR(BINARY)`
99+
Binary(&'a [u8]),
100+
}
101+
102+
impl VecRef<'_> {
103+
pub(crate) fn to_dpi(&self) -> Result<dpiVectorInfo> {
104+
match self {
105+
VecRef::Float32(slice) => Ok(dpiVectorInfo {
106+
format: DPI_VECTOR_FORMAT_FLOAT32 as u8,
107+
numDimensions: slice.len().try_into()?,
108+
dimensionSize: 4,
109+
dimensions: dpiVectorDimensionBuffer {
110+
asPtr: slice.as_ptr() as *mut c_void,
111+
},
112+
}),
113+
VecRef::Float64(slice) => Ok(dpiVectorInfo {
114+
format: DPI_VECTOR_FORMAT_FLOAT64 as u8,
115+
numDimensions: slice.len().try_into()?,
116+
dimensionSize: 8,
117+
dimensions: dpiVectorDimensionBuffer {
118+
asPtr: slice.as_ptr() as *mut c_void,
119+
},
120+
}),
121+
VecRef::Int8(slice) => Ok(dpiVectorInfo {
122+
format: DPI_VECTOR_FORMAT_INT8 as u8,
123+
numDimensions: slice.len().try_into()?,
124+
dimensionSize: 1,
125+
dimensions: dpiVectorDimensionBuffer {
126+
asPtr: slice.as_ptr() as *mut c_void,
127+
},
128+
}),
129+
VecRef::Binary(slice) => Ok(dpiVectorInfo {
130+
format: DPI_VECTOR_FORMAT_BINARY as u8,
131+
numDimensions: (slice.len() * 8).try_into()?,
132+
dimensionSize: 1,
133+
dimensions: dpiVectorDimensionBuffer {
134+
asPtr: slice.as_ptr() as *mut c_void,
135+
},
136+
}),
137+
}
138+
}
139+
140+
/// Returns vector dimension element format
141+
///
142+
/// # Examples
143+
///
144+
/// ```
145+
/// # use oracle::sql_type::vector::VecFmt;
146+
/// # use oracle::sql_type::vector::VecRef;
147+
/// // Refernce to float32 vector data.
148+
/// let vec_ref = VecRef::Float32(&[0.001, 3.0, 5.3]);
149+
///
150+
/// assert_eq!(vec_ref.format(), VecFmt::Float32);
151+
/// # Ok::<(), Box<dyn std::error::Error>>(())
152+
/// ```
153+
pub fn format(&self) -> VecFmt {
154+
match self {
155+
VecRef::Float32(_) => VecFmt::Float32,
156+
VecRef::Float64(_) => VecFmt::Float64,
157+
VecRef::Int8(_) => VecFmt::Int8,
158+
VecRef::Binary(_) => VecFmt::Binary,
159+
}
160+
}
161+
162+
/// Gets the containing vector data as slice
163+
///
164+
/// # Examples
165+
///
166+
/// ```
167+
/// # use oracle::sql_type::vector::VecRef;
168+
/// // Refernce to float32 vector data.
169+
/// let vec_ref = VecRef::Float32(&[0.001, 3.0, 5.3]);
170+
///
171+
/// // Gets as a slice of [f32]
172+
/// assert_eq!(vec_ref.as_slice::<f32>()?, &[0.001, 3.0, 5.3]);
173+
///
174+
/// // Errors for other vector dimension types.
175+
/// assert!(vec_ref.as_slice::<f64>().is_err());
176+
/// assert!(vec_ref.as_slice::<i8>().is_err());
177+
/// assert!(vec_ref.as_slice::<u8>().is_err());
178+
/// # Ok::<(), Box<dyn std::error::Error>>(())
179+
/// ```
180+
pub fn as_slice<T>(&self) -> Result<&[T]>
181+
where
182+
T: VectorFormat,
183+
{
184+
T::vec_ref_to_slice(self)
185+
}
186+
187+
fn oracle_type(&self) -> OracleType {
188+
match self {
189+
VecRef::Float32(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Float32),
190+
VecRef::Float64(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Float64),
191+
VecRef::Int8(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Int8),
192+
VecRef::Binary(slice) => OracleType::Vector(slice.len() as u32 * 8, VecFmt::Binary),
193+
}
194+
}
195+
}
196+
197+
impl<'a, T> From<&'a [T]> for VecRef<'a>
198+
where
199+
T: VectorFormat,
200+
{
201+
fn from(s: &'a [T]) -> VecRef<'a> {
202+
T::slice_to_vec_ref(s)
203+
}
204+
}
205+
206+
impl<'a, T> TryFrom<VecRef<'a>> for &'a [T]
207+
where
208+
T: VectorFormat,
209+
{
210+
type Error = Error;
211+
212+
fn try_from(s: VecRef) -> Result<&[T]> {
213+
T::vec_ref_to_slice(&s)
214+
}
215+
}
216+
217+
impl ToSqlNull for VecRef<'_> {
218+
fn oratype_for_null(_conn: &Connection) -> Result<OracleType> {
219+
Ok(OracleType::Vector(0, VecFmt::Flexible))
220+
}
221+
}
222+
223+
impl ToSql for VecRef<'_> {
224+
fn oratype(&self, _conn: &Connection) -> Result<OracleType> {
225+
Ok(OracleType::Vector(0, VecFmt::Flexible))
226+
}
227+
fn to_sql(&self, val: &mut SqlValue) -> Result<()> {
228+
val.set_vec_ref(self, "VecRef")
229+
}
230+
}
231+
232+
/// Trait for vector dimension element type
233+
///
234+
/// This trait is sealed and cannot be implemented for types outside of the `oracle` crate.
235+
pub trait VectorFormat: private::Sealed + Sized {
236+
#[doc(hidden)]
237+
fn slice_to_vec_ref(s: &[Self]) -> VecRef;
238+
#[doc(hidden)]
239+
fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]>;
240+
}
241+
242+
/// For the element type of Oracle data type `VECTOR(FLOAT32)`
243+
impl VectorFormat for f32 {
244+
fn slice_to_vec_ref(s: &[Self]) -> VecRef {
245+
VecRef::Float32(s)
246+
}
247+
fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> {
248+
match s {
249+
VecRef::Float32(s) => Ok(s),
250+
_ => Err(Error::new(
251+
ErrorKind::InvalidTypeConversion,
252+
format!("Could not convert {} to &[f32]", s.oracle_type()),
253+
)),
254+
}
255+
}
256+
}
257+
258+
/// For the element type of Oracle data type `VECTOR(FLOAT64)`
259+
impl VectorFormat for f64 {
260+
fn slice_to_vec_ref(s: &[Self]) -> VecRef {
261+
VecRef::Float64(s)
262+
}
263+
fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> {
264+
match s {
265+
VecRef::Float64(s) => Ok(s),
266+
_ => Err(Error::new(
267+
ErrorKind::InvalidTypeConversion,
268+
format!("Could not convert {} to &[f64]", s.oracle_type()),
269+
)),
270+
}
271+
}
272+
}
273+
274+
/// For the element type of Oracle data type `VECTOR(INT8)`
275+
impl VectorFormat for i8 {
276+
fn slice_to_vec_ref(s: &[Self]) -> VecRef {
277+
VecRef::Int8(s)
278+
}
279+
fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> {
280+
match s {
281+
VecRef::Int8(s) => Ok(s),
282+
_ => Err(Error::new(
283+
ErrorKind::InvalidTypeConversion,
284+
format!("Could not convert {} to &[i8]", s.oracle_type()),
285+
)),
286+
}
287+
}
288+
}
289+
290+
/// For the element type of Oracle data type `VECTOR(BINARY)`
291+
impl VectorFormat for u8 {
292+
fn slice_to_vec_ref(s: &[Self]) -> VecRef {
293+
VecRef::Binary(s)
294+
}
295+
fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> {
296+
match s {
297+
VecRef::Binary(s) => Ok(s),
298+
_ => Err(Error::new(
299+
ErrorKind::InvalidTypeConversion,
300+
format!("Could not convert {} to &[u8]", s.oracle_type()),
301+
)),
302+
}
303+
}
304+
}
305+
78306
#[cfg(test)]
79307
mod tests {
80308
use crate::sql_type::vector::VecFmt;
309+
use crate::sql_type::vector::VecRef;
81310
use crate::sql_type::OracleType;
82311
use crate::test_util;
83312
use crate::Result;
@@ -146,4 +375,44 @@ mod tests {
146375
}
147376
Ok(())
148377
}
378+
379+
#[test]
380+
fn to_sql() -> Result<()> {
381+
let conn = test_util::connect()?;
382+
383+
if !test_util::check_version(&conn, &test_util::VER23, &test_util::VER23)? {
384+
return Ok(());
385+
}
386+
let binary_vec = test_util::check_version(&conn, &test_util::VER23_5, &test_util::VER23_5)?;
387+
conn.execute("delete from test_vector_type", &[])?;
388+
let mut stmt = conn
389+
.statement("insert into test_vector_type(id, vec) values(:1, :2)")
390+
.build()?;
391+
let mut expected_data = vec![];
392+
stmt.execute(&[&1, &VecRef::Float32(&[1.0, 1.25, 1.5])])?;
393+
expected_data.push((1, "FLOAT32", "[1.0E+000,1.25E+000,1.5E+000]"));
394+
stmt.execute(&[&2, &VecRef::Float64(&[2.0, 2.25, 2.5])])?;
395+
expected_data.push((2, "FLOAT64", "[2.0E+000,2.25E+000,2.5E+000]"));
396+
stmt.execute(&[&3, &VecRef::Int8(&[3, 4, 5])])?;
397+
expected_data.push((3, "INT8", "[3,4,5]"));
398+
if binary_vec {
399+
stmt.execute(&[&4, &VecRef::Binary(&[6, 7, 8])])?;
400+
expected_data.push((4, "BINARY", "[6,7,8]"));
401+
}
402+
let mut index = 0;
403+
for row_result in conn.query_as::<(i32, String, String)>(
404+
"select id, vector_dimension_format(vec), from_vector(vec) from test_vector_type order by id",
405+
&[],
406+
)? {
407+
let row = row_result?;
408+
assert!(index < expected_data.len());
409+
let data = &expected_data[index];
410+
assert_eq!(row.0, data.0);
411+
assert_eq!(row.1, data.1);
412+
assert_eq!(row.2, data.2);
413+
index += 1;
414+
}
415+
assert_eq!(index, expected_data.len());
416+
Ok(())
417+
}
149418
}

src/sql_value.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
use crate::chkerr;
1717
use crate::connection::Conn;
18+
use crate::sql_type::vector::VecRef;
1819
use crate::sql_type::Bfile;
1920
use crate::sql_type::Blob;
2021
use crate::sql_type::Clob;
@@ -898,6 +899,17 @@ impl SqlValue<'_> {
898899
Ok(())
899900
}
900901

902+
fn set_vec_ref_unchecked(&mut self, vec: &VecRef) -> Result<()> {
903+
let data = self.data()?;
904+
let mut vec = vec.to_dpi()?;
905+
chkerr!(
906+
self.ctxt(),
907+
dpiVector_setValue(data.value.asVector, &mut vec)
908+
);
909+
data.isNull = 0;
910+
Ok(())
911+
}
912+
901913
pub(crate) fn clone_except_fetch_array_buffer(&self) -> Result<SqlValue<'static>> {
902914
if let DpiData::Var(ref var) = self.data {
903915
Ok(SqlValue {
@@ -1370,6 +1382,13 @@ impl SqlValue<'_> {
13701382
self.invalid_conversion_from_rust_type("Nclob")
13711383
}
13721384

1385+
pub(crate) fn set_vec_ref(&mut self, val: &VecRef, typename: &str) -> Result<()> {
1386+
match self.native_type {
1387+
NativeType::Vector => self.set_vec_ref_unchecked(val),
1388+
_ => self.invalid_conversion_from_rust_type(typename),
1389+
}
1390+
}
1391+
13731392
pub(crate) fn clone_with_narrow_lifetime(&self) -> Result<SqlValue> {
13741393
Ok(SqlValue {
13751394
conn: self.conn.clone(),

0 commit comments

Comments
 (0)