Skip to content

Commit b66fb90

Browse files
alamb and tlm365 authored
[arrow-cast] Support cast numeric to string view (alternate) (#6816) (#6944)
* [arrow-cast] Support cast numeric to string view
* fix test

---------

Signed-off-by: Tai Le Manh <[email protected]>
Co-authored-by: Tai Le Manh <[email protected]>
1 parent 955180b commit b66fb90

File tree

2 files changed: +207 -107 lines changed

arrow-cast/src/cast/mod.rs

Lines changed: 183 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -182,8 +182,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
182182
(Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) |
183183
// decimal to signed numeric
184184
(Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true,
185-
// decimal to Utf8
186-
(Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true,
185+
// decimal to string
186+
(Decimal128(_, _) | Decimal256(_, _), Utf8View | Utf8 | LargeUtf8) => true,
187187
// string to decimal
188188
(Utf8View | Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
189189
(Struct(from_fields), Struct(to_fields)) => {
@@ -232,6 +232,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
232232
(BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View ) => true,
233233
(Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
234234
(_, Utf8 | LargeUtf8) => from_type.is_primitive(),
235+
(_, Utf8View) => from_type.is_numeric(),
235236

236237
(_, Binary | LargeBinary) => from_type.is_integer(),
237238

@@ -919,6 +920,7 @@ pub fn cast_with_options(
919920
Float64 => cast_decimal_to_float::<Decimal128Type, Float64Type, _>(array, |x| {
920921
x as f64 / 10_f64.powi(*scale as i32)
921922
}),
923+
Utf8View => value_to_string_view(array, cast_options),
922924
Utf8 => value_to_string::<i32>(array, cast_options),
923925
LargeUtf8 => value_to_string::<i64>(array, cast_options),
924926
Null => Ok(new_null_array(to_type, array.len())),
@@ -984,6 +986,7 @@ pub fn cast_with_options(
984986
Float64 => cast_decimal_to_float::<Decimal256Type, Float64Type, _>(array, |x| {
985987
x.to_f64().unwrap() / 10_f64.powi(*scale as i32)
986988
}),
989+
Utf8View => value_to_string_view(array, cast_options),
987990
Utf8 => value_to_string::<i32>(array, cast_options),
988991
LargeUtf8 => value_to_string::<i64>(array, cast_options),
989992
Null => Ok(new_null_array(to_type, array.len())),
@@ -1464,6 +1467,9 @@ pub fn cast_with_options(
14641467
(BinaryView, _) => Err(ArrowError::CastError(format!(
14651468
"Casting from {from_type:?} to {to_type:?} not supported",
14661469
))),
1470+
(from_type, Utf8View) if from_type.is_numeric() => {
1471+
value_to_string_view(array, cast_options)
1472+
}
14671473
(from_type, LargeUtf8) if from_type.is_primitive() => {
14681474
value_to_string::<i64>(array, cast_options)
14691475
}
@@ -3712,6 +3718,55 @@ mod tests {
37123718
assert_eq!(10.0, c.value(3));
37133719
}
37143720

3721+
#[test]
3722+
fn test_cast_int_to_utf8view() {
3723+
let inputs = vec![
3724+
Arc::new(Int8Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3725+
Arc::new(Int16Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3726+
Arc::new(Int32Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3727+
Arc::new(Int64Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3728+
Arc::new(UInt8Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3729+
Arc::new(UInt16Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3730+
Arc::new(UInt32Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3731+
Arc::new(UInt64Array::from(vec![None, Some(8), Some(9), Some(10)])) as ArrayRef,
3732+
];
3733+
let expected: ArrayRef = Arc::new(StringViewArray::from(vec![
3734+
None,
3735+
Some("8"),
3736+
Some("9"),
3737+
Some("10"),
3738+
]));
3739+
3740+
for array in inputs {
3741+
assert!(can_cast_types(array.data_type(), &DataType::Utf8View));
3742+
let arr = cast(&array, &DataType::Utf8View).unwrap();
3743+
assert_eq!(expected.as_ref(), arr.as_ref());
3744+
}
3745+
}
3746+
3747+
#[test]
3748+
fn test_cast_float_to_utf8view() {
3749+
let inputs = vec![
3750+
Arc::new(Float16Array::from(vec![
3751+
Some(f16::from_f64(1.5)),
3752+
Some(f16::from_f64(2.5)),
3753+
None,
3754+
])) as ArrayRef,
3755+
Arc::new(Float32Array::from(vec![Some(1.5), Some(2.5), None])) as ArrayRef,
3756+
Arc::new(Float64Array::from(vec![Some(1.5), Some(2.5), None])) as ArrayRef,
3757+
];
3758+
3759+
let expected: ArrayRef =
3760+
Arc::new(StringViewArray::from(vec![Some("1.5"), Some("2.5"), None]));
3761+
3762+
for array in inputs {
3763+
println!("type: {}", array.data_type());
3764+
assert!(can_cast_types(array.data_type(), &DataType::Utf8View));
3765+
let arr = cast(&array, &DataType::Utf8View).unwrap();
3766+
assert_eq!(expected.as_ref(), arr.as_ref());
3767+
}
3768+
}
3769+
37153770
#[test]
37163771
fn test_cast_utf8_to_i32() {
37173772
let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]);
@@ -5185,41 +5240,46 @@ mod tests {
51855240
assert_eq!("2018-12-25T00:00:00", c.value(1));
51865241
}
51875242

5243+
// Cast Timestamp to Utf8View is not supported yet
5244+
// TODO: Implement casting from Timestamp to Utf8View
5245+
// https://github.com/apache/arrow-rs/issues/6734
5246+
macro_rules! assert_cast_timestamp_to_string {
5247+
($array:expr, $datatype:expr, $output_array_type: ty, $expected:expr) => {{
5248+
let out = cast(&$array, &$datatype).unwrap();
5249+
let actual = out
5250+
.as_any()
5251+
.downcast_ref::<$output_array_type>()
5252+
.unwrap()
5253+
.into_iter()
5254+
.collect::<Vec<_>>();
5255+
assert_eq!(actual, $expected);
5256+
}};
5257+
($array:expr, $datatype:expr, $output_array_type: ty, $options:expr, $expected:expr) => {{
5258+
let out = cast_with_options(&$array, &$datatype, &$options).unwrap();
5259+
let actual = out
5260+
.as_any()
5261+
.downcast_ref::<$output_array_type>()
5262+
.unwrap()
5263+
.into_iter()
5264+
.collect::<Vec<_>>();
5265+
assert_eq!(actual, $expected);
5266+
}};
5267+
}
5268+
51885269
#[test]
51895270
fn test_cast_timestamp_to_strings() {
51905271
// "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None
51915272
let array =
51925273
TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]);
5193-
let out = cast(&array, &DataType::Utf8).unwrap();
5194-
let out = out
5195-
.as_any()
5196-
.downcast_ref::<StringArray>()
5197-
.unwrap()
5198-
.into_iter()
5199-
.collect::<Vec<_>>();
5200-
assert_eq!(
5201-
out,
5202-
vec![
5203-
Some("1997-05-19T00:00:03.005"),
5204-
Some("2018-12-25T00:00:02.001"),
5205-
None
5206-
]
5207-
);
5208-
let out = cast(&array, &DataType::LargeUtf8).unwrap();
5209-
let out = out
5210-
.as_any()
5211-
.downcast_ref::<LargeStringArray>()
5212-
.unwrap()
5213-
.into_iter()
5214-
.collect::<Vec<_>>();
5215-
assert_eq!(
5216-
out,
5217-
vec![
5218-
Some("1997-05-19T00:00:03.005"),
5219-
Some("2018-12-25T00:00:02.001"),
5220-
None
5221-
]
5222-
);
5274+
let expected = vec![
5275+
Some("1997-05-19T00:00:03.005"),
5276+
Some("2018-12-25T00:00:02.001"),
5277+
None,
5278+
];
5279+
5280+
// assert_cast_timestamp_to_string!(array, DataType::Utf8View, StringViewArray, expected);
5281+
assert_cast_timestamp_to_string!(array, DataType::Utf8, StringArray, expected);
5282+
assert_cast_timestamp_to_string!(array, DataType::LargeUtf8, LargeStringArray, expected);
52235283
}
52245284

52255285
#[test]
@@ -5232,73 +5292,53 @@ mod tests {
52325292
.with_timestamp_format(Some(ts_format))
52335293
.with_timestamp_tz_format(Some(ts_format)),
52345294
};
5295+
52355296
// "2018-12-25T00:00:02.001", "1997-05-19T00:00:03.005", None
52365297
let array_without_tz =
52375298
TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None]);
5238-
let out = cast_with_options(&array_without_tz, &DataType::Utf8, &cast_options).unwrap();
5239-
let out = out
5240-
.as_any()
5241-
.downcast_ref::<StringArray>()
5242-
.unwrap()
5243-
.into_iter()
5244-
.collect::<Vec<_>>();
5245-
assert_eq!(
5246-
out,
5247-
vec![
5248-
Some("1997-05-19 00:00:03.005000"),
5249-
Some("2018-12-25 00:00:02.001000"),
5250-
None
5251-
]
5299+
let expected = vec![
5300+
Some("1997-05-19 00:00:03.005000"),
5301+
Some("2018-12-25 00:00:02.001000"),
5302+
None,
5303+
];
5304+
// assert_cast_timestamp_to_string!(array_without_tz, DataType::Utf8View, StringViewArray, cast_options, expected);
5305+
assert_cast_timestamp_to_string!(
5306+
array_without_tz,
5307+
DataType::Utf8,
5308+
StringArray,
5309+
cast_options,
5310+
expected
52525311
);
5253-
let out =
5254-
cast_with_options(&array_without_tz, &DataType::LargeUtf8, &cast_options).unwrap();
5255-
let out = out
5256-
.as_any()
5257-
.downcast_ref::<LargeStringArray>()
5258-
.unwrap()
5259-
.into_iter()
5260-
.collect::<Vec<_>>();
5261-
assert_eq!(
5262-
out,
5263-
vec![
5264-
Some("1997-05-19 00:00:03.005000"),
5265-
Some("2018-12-25 00:00:02.001000"),
5266-
None
5267-
]
5312+
assert_cast_timestamp_to_string!(
5313+
array_without_tz,
5314+
DataType::LargeUtf8,
5315+
LargeStringArray,
5316+
cast_options,
5317+
expected
52685318
);
52695319

52705320
let array_with_tz =
52715321
TimestampMillisecondArray::from(vec![Some(864000003005), Some(1545696002001), None])
52725322
.with_timezone(tz.to_string());
5273-
let out = cast_with_options(&array_with_tz, &DataType::Utf8, &cast_options).unwrap();
5274-
let out = out
5275-
.as_any()
5276-
.downcast_ref::<StringArray>()
5277-
.unwrap()
5278-
.into_iter()
5279-
.collect::<Vec<_>>();
5280-
assert_eq!(
5281-
out,
5282-
vec![
5283-
Some("1997-05-19 05:45:03.005000"),
5284-
Some("2018-12-25 05:45:02.001000"),
5285-
None
5286-
]
5323+
let expected = vec![
5324+
Some("1997-05-19 05:45:03.005000"),
5325+
Some("2018-12-25 05:45:02.001000"),
5326+
None,
5327+
];
5328+
// assert_cast_timestamp_to_string!(array_with_tz, DataType::Utf8View, StringViewArray, cast_options, expected);
5329+
assert_cast_timestamp_to_string!(
5330+
array_with_tz,
5331+
DataType::Utf8,
5332+
StringArray,
5333+
cast_options,
5334+
expected
52875335
);
5288-
let out = cast_with_options(&array_with_tz, &DataType::LargeUtf8, &cast_options).unwrap();
5289-
let out = out
5290-
.as_any()
5291-
.downcast_ref::<LargeStringArray>()
5292-
.unwrap()
5293-
.into_iter()
5294-
.collect::<Vec<_>>();
5295-
assert_eq!(
5296-
out,
5297-
vec![
5298-
Some("1997-05-19 05:45:03.005000"),
5299-
Some("2018-12-25 05:45:02.001000"),
5300-
None
5301-
]
5336+
assert_cast_timestamp_to_string!(
5337+
array_with_tz,
5338+
DataType::LargeUtf8,
5339+
LargeStringArray,
5340+
cast_options,
5341+
expected
53025342
);
53035343
}
53045344

@@ -9153,26 +9193,51 @@ mod tests {
91539193
}
91549194

91559195
#[test]
9156-
fn test_cast_decimal_to_utf8() {
9196+
fn test_cast_decimal_to_string() {
9197+
assert!(can_cast_types(
9198+
&DataType::Decimal128(10, 4),
9199+
&DataType::Utf8View
9200+
));
9201+
assert!(can_cast_types(
9202+
&DataType::Decimal256(38, 10),
9203+
&DataType::Utf8View
9204+
));
9205+
9206+
macro_rules! assert_decimal_values {
9207+
($array:expr) => {
9208+
let c = $array;
9209+
assert_eq!("1123.454", c.value(0));
9210+
assert_eq!("2123.456", c.value(1));
9211+
assert_eq!("-3123.453", c.value(2));
9212+
assert_eq!("-3123.456", c.value(3));
9213+
assert_eq!("0.000", c.value(4));
9214+
assert_eq!("0.123", c.value(5));
9215+
assert_eq!("1234.567", c.value(6));
9216+
assert_eq!("-1234.567", c.value(7));
9217+
assert!(c.is_null(8));
9218+
};
9219+
}
9220+
91579221
fn test_decimal_to_string<IN: ArrowPrimitiveType, OffsetSize: OffsetSizeTrait>(
91589222
output_type: DataType,
91599223
array: PrimitiveArray<IN>,
91609224
) {
91619225
let b = cast(&array, &output_type).unwrap();
91629226

91639227
assert_eq!(b.data_type(), &output_type);
9164-
let c = b.as_string::<OffsetSize>();
9165-
9166-
assert_eq!("1123.454", c.value(0));
9167-
assert_eq!("2123.456", c.value(1));
9168-
assert_eq!("-3123.453", c.value(2));
9169-
assert_eq!("-3123.456", c.value(3));
9170-
assert_eq!("0.000", c.value(4));
9171-
assert_eq!("0.123", c.value(5));
9172-
assert_eq!("1234.567", c.value(6));
9173-
assert_eq!("-1234.567", c.value(7));
9174-
assert!(c.is_null(8));
9228+
match b.data_type() {
9229+
DataType::Utf8View => {
9230+
let c = b.as_string_view();
9231+
assert_decimal_values!(c);
9232+
}
9233+
DataType::Utf8 | DataType::LargeUtf8 => {
9234+
let c = b.as_string::<OffsetSize>();
9235+
assert_decimal_values!(c);
9236+
}
9237+
_ => (),
9238+
}
91759239
}
9240+
91769241
let array128: Vec<Option<i128>> = vec![
91779242
Some(1123454),
91789243
Some(2123456),
@@ -9184,22 +9249,33 @@ mod tests {
91849249
Some(-123456789),
91859250
None,
91869251
];
9252+
let array256: Vec<Option<i256>> = array128
9253+
.iter()
9254+
.map(|num| num.map(i256::from_i128))
9255+
.collect();
91879256

9188-
let array256: Vec<Option<i256>> = array128.iter().map(|v| v.map(i256::from_i128)).collect();
9189-
9190-
test_decimal_to_string::<arrow_array::types::Decimal128Type, i32>(
9257+
test_decimal_to_string::<Decimal128Type, i32>(
9258+
DataType::Utf8View,
9259+
create_decimal_array(array128.clone(), 7, 3).unwrap(),
9260+
);
9261+
test_decimal_to_string::<Decimal128Type, i32>(
91919262
DataType::Utf8,
91929263
create_decimal_array(array128.clone(), 7, 3).unwrap(),
91939264
);
9194-
test_decimal_to_string::<arrow_array::types::Decimal128Type, i64>(
9265+
test_decimal_to_string::<Decimal128Type, i64>(
91959266
DataType::LargeUtf8,
91969267
create_decimal_array(array128, 7, 3).unwrap(),
91979268
);
9198-
test_decimal_to_string::<arrow_array::types::Decimal256Type, i32>(
9269+
9270+
test_decimal_to_string::<Decimal256Type, i32>(
9271+
DataType::Utf8View,
9272+
create_decimal256_array(array256.clone(), 7, 3).unwrap(),
9273+
);
9274+
test_decimal_to_string::<Decimal256Type, i32>(
91999275
DataType::Utf8,
92009276
create_decimal256_array(array256.clone(), 7, 3).unwrap(),
92019277
);
9202-
test_decimal_to_string::<arrow_array::types::Decimal256Type, i64>(
9278+
test_decimal_to_string::<Decimal256Type, i64>(
92039279
DataType::LargeUtf8,
92049280
create_decimal256_array(array256, 7, 3).unwrap(),
92059281
);

0 commit comments

Comments
 (0)