|
13 | 13 | // (ii) the Apache License v 2.0. (http://www.apache.org/licenses/LICENSE-2.0)
|
14 | 14 | //-----------------------------------------------------------------------------
|
15 | 15 |
|
| 16 | +use crate::private; |
| 17 | +use crate::sql_type::OracleType; |
| 18 | +use crate::sql_type::SqlValue; |
| 19 | +use crate::sql_type::ToSql; |
| 20 | +use crate::sql_type::ToSqlNull; |
| 21 | +use crate::Connection; |
16 | 22 | use crate::Error;
|
| 23 | +use crate::ErrorKind; |
17 | 24 | use crate::Result;
|
18 | 25 | use odpic_sys::*;
|
19 | 26 | use std::fmt;
|
| 27 | +use std::os::raw::c_void; |
20 | 28 |
|
21 | 29 | /// Vector dimension element format
|
22 | 30 | ///
|
@@ -75,9 +83,230 @@ impl fmt::Display for VecFmt {
|
75 | 83 | }
|
76 | 84 | }
|
77 | 85 |
|
| 86 | +/// Reference to vector dimension elements |
| 87 | +/// |
| 88 | +/// See the [module-level documentation](index.html) for more. |
| 89 | +#[derive(Clone, Debug, PartialEq)] |
| 90 | +#[non_exhaustive] |
| 91 | +pub enum VecRef<'a> { |
| 92 | + /// Wraps `[f32]` slice data as Oracle data type `VECTOR(FLOAT32)` |
| 93 | + Float32(&'a [f32]), |
| 94 | + /// Wraps `[f64]` slice data as Oracle data type `VECTOR(FLOAT64)` |
| 95 | + Float64(&'a [f64]), |
| 96 | + /// Wraps `[i8]` slice data as Oracle data type `VECTOR(INT8)` |
| 97 | + Int8(&'a [i8]), |
| 98 | + /// Wraps `[u8]` slice data as Oracle data type `VECTOR(BINARY)` |
| 99 | + Binary(&'a [u8]), |
| 100 | +} |
| 101 | + |
| 102 | +impl VecRef<'_> { |
| 103 | + pub(crate) fn to_dpi(&self) -> Result<dpiVectorInfo> { |
| 104 | + match self { |
| 105 | + VecRef::Float32(slice) => Ok(dpiVectorInfo { |
| 106 | + format: DPI_VECTOR_FORMAT_FLOAT32 as u8, |
| 107 | + numDimensions: slice.len().try_into()?, |
| 108 | + dimensionSize: 4, |
| 109 | + dimensions: dpiVectorDimensionBuffer { |
| 110 | + asPtr: slice.as_ptr() as *mut c_void, |
| 111 | + }, |
| 112 | + }), |
| 113 | + VecRef::Float64(slice) => Ok(dpiVectorInfo { |
| 114 | + format: DPI_VECTOR_FORMAT_FLOAT64 as u8, |
| 115 | + numDimensions: slice.len().try_into()?, |
| 116 | + dimensionSize: 8, |
| 117 | + dimensions: dpiVectorDimensionBuffer { |
| 118 | + asPtr: slice.as_ptr() as *mut c_void, |
| 119 | + }, |
| 120 | + }), |
| 121 | + VecRef::Int8(slice) => Ok(dpiVectorInfo { |
| 122 | + format: DPI_VECTOR_FORMAT_INT8 as u8, |
| 123 | + numDimensions: slice.len().try_into()?, |
| 124 | + dimensionSize: 1, |
| 125 | + dimensions: dpiVectorDimensionBuffer { |
| 126 | + asPtr: slice.as_ptr() as *mut c_void, |
| 127 | + }, |
| 128 | + }), |
| 129 | + VecRef::Binary(slice) => Ok(dpiVectorInfo { |
| 130 | + format: DPI_VECTOR_FORMAT_BINARY as u8, |
| 131 | + numDimensions: (slice.len() * 8).try_into()?, |
| 132 | + dimensionSize: 1, |
| 133 | + dimensions: dpiVectorDimensionBuffer { |
| 134 | + asPtr: slice.as_ptr() as *mut c_void, |
| 135 | + }, |
| 136 | + }), |
| 137 | + } |
| 138 | + } |
| 139 | + |
| 140 | + /// Returns vector dimension element format |
| 141 | + /// |
| 142 | + /// # Examples |
| 143 | + /// |
| 144 | + /// ``` |
| 145 | + /// # use oracle::sql_type::vector::VecFmt; |
| 146 | + /// # use oracle::sql_type::vector::VecRef; |
| 147 | + /// // Refernce to float32 vector data. |
| 148 | + /// let vec_ref = VecRef::Float32(&[0.001, 3.0, 5.3]); |
| 149 | + /// |
| 150 | + /// assert_eq!(vec_ref.format(), VecFmt::Float32); |
| 151 | + /// # Ok::<(), Box<dyn std::error::Error>>(()) |
| 152 | + /// ``` |
| 153 | + pub fn format(&self) -> VecFmt { |
| 154 | + match self { |
| 155 | + VecRef::Float32(_) => VecFmt::Float32, |
| 156 | + VecRef::Float64(_) => VecFmt::Float64, |
| 157 | + VecRef::Int8(_) => VecFmt::Int8, |
| 158 | + VecRef::Binary(_) => VecFmt::Binary, |
| 159 | + } |
| 160 | + } |
| 161 | + |
| 162 | + /// Gets the containing vector data as slice |
| 163 | + /// |
| 164 | + /// # Examples |
| 165 | + /// |
| 166 | + /// ``` |
| 167 | + /// # use oracle::sql_type::vector::VecRef; |
| 168 | + /// // Refernce to float32 vector data. |
| 169 | + /// let vec_ref = VecRef::Float32(&[0.001, 3.0, 5.3]); |
| 170 | + /// |
| 171 | + /// // Gets as a slice of [f32] |
| 172 | + /// assert_eq!(vec_ref.as_slice::<f32>()?, &[0.001, 3.0, 5.3]); |
| 173 | + /// |
| 174 | + /// // Errors for other vector dimension types. |
| 175 | + /// assert!(vec_ref.as_slice::<f64>().is_err()); |
| 176 | + /// assert!(vec_ref.as_slice::<i8>().is_err()); |
| 177 | + /// assert!(vec_ref.as_slice::<u8>().is_err()); |
| 178 | + /// # Ok::<(), Box<dyn std::error::Error>>(()) |
| 179 | + /// ``` |
| 180 | + pub fn as_slice<T>(&self) -> Result<&[T]> |
| 181 | + where |
| 182 | + T: VectorFormat, |
| 183 | + { |
| 184 | + T::vec_ref_to_slice(self) |
| 185 | + } |
| 186 | + |
| 187 | + fn oracle_type(&self) -> OracleType { |
| 188 | + match self { |
| 189 | + VecRef::Float32(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Float32), |
| 190 | + VecRef::Float64(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Float64), |
| 191 | + VecRef::Int8(slice) => OracleType::Vector(slice.len() as u32, VecFmt::Int8), |
| 192 | + VecRef::Binary(slice) => OracleType::Vector(slice.len() as u32 * 8, VecFmt::Binary), |
| 193 | + } |
| 194 | + } |
| 195 | +} |
| 196 | + |
| 197 | +impl<'a, T> From<&'a [T]> for VecRef<'a> |
| 198 | +where |
| 199 | + T: VectorFormat, |
| 200 | +{ |
| 201 | + fn from(s: &'a [T]) -> VecRef<'a> { |
| 202 | + T::slice_to_vec_ref(s) |
| 203 | + } |
| 204 | +} |
| 205 | + |
| 206 | +impl<'a, T> TryFrom<VecRef<'a>> for &'a [T] |
| 207 | +where |
| 208 | + T: VectorFormat, |
| 209 | +{ |
| 210 | + type Error = Error; |
| 211 | + |
| 212 | + fn try_from(s: VecRef) -> Result<&[T]> { |
| 213 | + T::vec_ref_to_slice(&s) |
| 214 | + } |
| 215 | +} |
| 216 | + |
| 217 | +impl ToSqlNull for VecRef<'_> { |
| 218 | + fn oratype_for_null(_conn: &Connection) -> Result<OracleType> { |
| 219 | + Ok(OracleType::Vector(0, VecFmt::Flexible)) |
| 220 | + } |
| 221 | +} |
| 222 | + |
| 223 | +impl ToSql for VecRef<'_> { |
| 224 | + fn oratype(&self, _conn: &Connection) -> Result<OracleType> { |
| 225 | + Ok(OracleType::Vector(0, VecFmt::Flexible)) |
| 226 | + } |
| 227 | + fn to_sql(&self, val: &mut SqlValue) -> Result<()> { |
| 228 | + val.set_vec_ref(self, "VecRef") |
| 229 | + } |
| 230 | +} |
| 231 | + |
| 232 | +/// Trait for vector dimension element type |
| 233 | +/// |
| 234 | +/// This trait is sealed and cannot be implemented for types outside of the `oracle` crate. |
| 235 | +pub trait VectorFormat: private::Sealed + Sized { |
| 236 | + #[doc(hidden)] |
| 237 | + fn slice_to_vec_ref(s: &[Self]) -> VecRef; |
| 238 | + #[doc(hidden)] |
| 239 | + fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]>; |
| 240 | +} |
| 241 | + |
| 242 | +/// For the element type of Oracle data type `VECTOR(FLOAT32)` |
| 243 | +impl VectorFormat for f32 { |
| 244 | + fn slice_to_vec_ref(s: &[Self]) -> VecRef { |
| 245 | + VecRef::Float32(s) |
| 246 | + } |
| 247 | + fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> { |
| 248 | + match s { |
| 249 | + VecRef::Float32(s) => Ok(s), |
| 250 | + _ => Err(Error::new( |
| 251 | + ErrorKind::InvalidTypeConversion, |
| 252 | + format!("Could not convert {} to &[f32]", s.oracle_type()), |
| 253 | + )), |
| 254 | + } |
| 255 | + } |
| 256 | +} |
| 257 | + |
| 258 | +/// For the element type of Oracle data type `VECTOR(FLOAT64)` |
| 259 | +impl VectorFormat for f64 { |
| 260 | + fn slice_to_vec_ref(s: &[Self]) -> VecRef { |
| 261 | + VecRef::Float64(s) |
| 262 | + } |
| 263 | + fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> { |
| 264 | + match s { |
| 265 | + VecRef::Float64(s) => Ok(s), |
| 266 | + _ => Err(Error::new( |
| 267 | + ErrorKind::InvalidTypeConversion, |
| 268 | + format!("Could not convert {} to &[f64]", s.oracle_type()), |
| 269 | + )), |
| 270 | + } |
| 271 | + } |
| 272 | +} |
| 273 | + |
| 274 | +/// For the element type of Oracle data type `VECTOR(INT8)` |
| 275 | +impl VectorFormat for i8 { |
| 276 | + fn slice_to_vec_ref(s: &[Self]) -> VecRef { |
| 277 | + VecRef::Int8(s) |
| 278 | + } |
| 279 | + fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> { |
| 280 | + match s { |
| 281 | + VecRef::Int8(s) => Ok(s), |
| 282 | + _ => Err(Error::new( |
| 283 | + ErrorKind::InvalidTypeConversion, |
| 284 | + format!("Could not convert {} to &[i8]", s.oracle_type()), |
| 285 | + )), |
| 286 | + } |
| 287 | + } |
| 288 | +} |
| 289 | + |
| 290 | +/// For the element type of Oracle data type `VECTOR(BINARY)` |
| 291 | +impl VectorFormat for u8 { |
| 292 | + fn slice_to_vec_ref(s: &[Self]) -> VecRef { |
| 293 | + VecRef::Binary(s) |
| 294 | + } |
| 295 | + fn vec_ref_to_slice<'a>(s: &VecRef<'a>) -> Result<&'a [Self]> { |
| 296 | + match s { |
| 297 | + VecRef::Binary(s) => Ok(s), |
| 298 | + _ => Err(Error::new( |
| 299 | + ErrorKind::InvalidTypeConversion, |
| 300 | + format!("Could not convert {} to &[u8]", s.oracle_type()), |
| 301 | + )), |
| 302 | + } |
| 303 | + } |
| 304 | +} |
| 305 | + |
78 | 306 | #[cfg(test)]
|
79 | 307 | mod tests {
|
80 | 308 | use crate::sql_type::vector::VecFmt;
|
| 309 | + use crate::sql_type::vector::VecRef; |
81 | 310 | use crate::sql_type::OracleType;
|
82 | 311 | use crate::test_util;
|
83 | 312 | use crate::Result;
|
@@ -146,4 +375,44 @@ mod tests {
|
146 | 375 | }
|
147 | 376 | Ok(())
|
148 | 377 | }
|
| 378 | + |
| 379 | + #[test] |
| 380 | + fn to_sql() -> Result<()> { |
| 381 | + let conn = test_util::connect()?; |
| 382 | + |
| 383 | + if !test_util::check_version(&conn, &test_util::VER23, &test_util::VER23)? { |
| 384 | + return Ok(()); |
| 385 | + } |
| 386 | + let binary_vec = test_util::check_version(&conn, &test_util::VER23_5, &test_util::VER23_5)?; |
| 387 | + conn.execute("delete from test_vector_type", &[])?; |
| 388 | + let mut stmt = conn |
| 389 | + .statement("insert into test_vector_type(id, vec) values(:1, :2)") |
| 390 | + .build()?; |
| 391 | + let mut expected_data = vec![]; |
| 392 | + stmt.execute(&[&1, &VecRef::Float32(&[1.0, 1.25, 1.5])])?; |
| 393 | + expected_data.push((1, "FLOAT32", "[1.0E+000,1.25E+000,1.5E+000]")); |
| 394 | + stmt.execute(&[&2, &VecRef::Float64(&[2.0, 2.25, 2.5])])?; |
| 395 | + expected_data.push((2, "FLOAT64", "[2.0E+000,2.25E+000,2.5E+000]")); |
| 396 | + stmt.execute(&[&3, &VecRef::Int8(&[3, 4, 5])])?; |
| 397 | + expected_data.push((3, "INT8", "[3,4,5]")); |
| 398 | + if binary_vec { |
| 399 | + stmt.execute(&[&4, &VecRef::Binary(&[6, 7, 8])])?; |
| 400 | + expected_data.push((4, "BINARY", "[6,7,8]")); |
| 401 | + } |
| 402 | + let mut index = 0; |
| 403 | + for row_result in conn.query_as::<(i32, String, String)>( |
| 404 | + "select id, vector_dimension_format(vec), from_vector(vec) from test_vector_type order by id", |
| 405 | + &[], |
| 406 | + )? { |
| 407 | + let row = row_result?; |
| 408 | + assert!(index < expected_data.len()); |
| 409 | + let data = &expected_data[index]; |
| 410 | + assert_eq!(row.0, data.0); |
| 411 | + assert_eq!(row.1, data.1); |
| 412 | + assert_eq!(row.2, data.2); |
| 413 | + index += 1; |
| 414 | + } |
| 415 | + assert_eq!(index, expected_data.len()); |
| 416 | + Ok(()) |
| 417 | + } |
149 | 418 | }
|
0 commit comments