Skip to content

Commit e4cb337

Browse files
andygrovehimadripalhimadripal
authored
fix: decimal conversion loses value on lower precision (#6836) (#6936)
* decimal conversion loses value on lower precision; it now throws an error on overflow. * fix review comments and fix formatting. * for the simple case of equal scale and bigger precision, no conversion is needed. revert whitespace changes, formatting check --------- Co-authored-by: Himadri Pal <[email protected]> Co-authored-by: himadripal <[email protected]>
1 parent 3366cb8 commit e4cb337

File tree

2 files changed

+116
-32
lines changed

2 files changed

+116
-32
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,13 @@ where
111111
O::Native::from_decimal(adjusted)
112112
};
113113

114-
Ok(match cast_options.safe {
115-
true => array.unary_opt(f),
116-
false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
114+
Ok(if cast_options.safe {
115+
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
116+
} else {
117+
array.try_unary(|x| {
118+
f(x).ok_or_else(|| error(x))
119+
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
120+
})?
117121
})
118122
}
119123

@@ -137,15 +141,20 @@ where
137141

138142
let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
139143

140-
Ok(match cast_options.safe {
141-
true => array.unary_opt(f),
142-
false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
144+
Ok(if cast_options.safe {
145+
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
146+
} else {
147+
array.try_unary(|x| {
148+
f(x).ok_or_else(|| error(x))
149+
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
150+
})?
143151
})
144152
}
145153

146154
// Only support one type of decimal cast operations
147155
pub(crate) fn cast_decimal_to_decimal_same_type<T>(
148156
array: &PrimitiveArray<T>,
157+
input_precision: u8,
149158
input_scale: i8,
150159
output_precision: u8,
151160
output_scale: i8,
@@ -155,29 +164,27 @@ where
155164
T: DecimalType,
156165
T::Native: DecimalCast + ArrowNativeTypeOp,
157166
{
158-
let array: PrimitiveArray<T> = match input_scale.cmp(&output_scale) {
159-
Ordering::Equal => {
160-
// the scale doesn't change, the native value don't need to be changed
167+
let array: PrimitiveArray<T> =
168+
if input_scale == output_scale && input_precision <= output_precision {
161169
array.clone()
162-
}
163-
Ordering::Greater => convert_to_smaller_scale_decimal::<T, T>(
164-
array,
165-
input_scale,
166-
output_precision,
167-
output_scale,
168-
cast_options,
169-
)?,
170-
Ordering::Less => {
171-
// input_scale < output_scale
170+
} else if input_scale < output_scale {
171+
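// input_scale < output_scale: values are scaled up, which may overflow the output precision.
// NOTE(review): equal scale with a smaller output precision is not handled by this branch (`<`, not `<=`) and falls through to the smaller-scale branch below; confirm this is intended.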
172172
convert_to_bigger_or_equal_scale_decimal::<T, T>(
173173
array,
174174
input_scale,
175175
output_precision,
176176
output_scale,
177177
cast_options,
178178
)?
179-
}
180-
};
179+
} else {
180+
convert_to_smaller_scale_decimal::<T, T>(
181+
array,
182+
input_scale,
183+
output_precision,
184+
output_scale,
185+
cast_options,
186+
)?
187+
};
181188

182189
Ok(Arc::new(array.with_precision_and_scale(
183190
output_precision,

arrow-cast/src/cast/mod.rs

Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -824,18 +824,20 @@ pub fn cast_with_options(
824824
(Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => {
825825
cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned())
826826
}
827-
(Decimal128(_, s1), Decimal128(p2, s2)) => {
827+
(Decimal128(p1, s1), Decimal128(p2, s2)) => {
828828
cast_decimal_to_decimal_same_type::<Decimal128Type>(
829829
array.as_primitive(),
830+
*p1,
830831
*s1,
831832
*p2,
832833
*s2,
833834
cast_options,
834835
)
835836
}
836-
(Decimal256(_, s1), Decimal256(p2, s2)) => {
837+
(Decimal256(p1, s1), Decimal256(p2, s2)) => {
837838
cast_decimal_to_decimal_same_type::<Decimal256Type>(
838839
array.as_primitive(),
840+
*p1,
839841
*s1,
840842
*p2,
841843
*s2,
@@ -2681,13 +2683,16 @@ mod tests {
26812683
// negative test
26822684
let array = vec![Some(123456), None];
26832685
let array = create_decimal_array(array, 10, 0).unwrap();
2684-
let result = cast(&array, &DataType::Decimal128(2, 2));
2685-
assert!(result.is_ok());
2686-
let array = result.unwrap();
2687-
let array: &Decimal128Array = array.as_primitive();
2688-
let err = array.validate_decimal_precision(2);
2686+
let result_safe = cast(&array, &DataType::Decimal128(2, 2));
2687+
assert!(result_safe.is_ok());
2688+
let options = CastOptions {
2689+
safe: false,
2690+
..Default::default()
2691+
};
2692+
2693+
let result_unsafe = cast_with_options(&array, &DataType::Decimal128(2, 2), &options);
26892694
assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. Max is 99",
2690-
err.unwrap_err().to_string());
2695+
result_unsafe.unwrap_err().to_string());
26912696
}
26922697

26932698
#[test]
@@ -8388,7 +8393,7 @@ mod tests {
83888393
let input_type = DataType::Decimal128(10, 3);
83898394
let output_type = DataType::Decimal256(10, 5);
83908395
assert!(can_cast_types(&input_type, &output_type));
8391-
let array = vec![Some(i128::MAX), Some(i128::MIN)];
8396+
let array = vec![Some(123456), Some(-123456)];
83928397
let input_decimal_array = create_decimal_array(array, 10, 3).unwrap();
83938398
let array = Arc::new(input_decimal_array) as ArrayRef;
83948399

@@ -8398,8 +8403,8 @@ mod tests {
83988403
Decimal256Array,
83998404
&output_type,
84008405
vec![
8401-
Some(i256::from_i128(i128::MAX).mul_wrapping(hundred)),
8402-
Some(i256::from_i128(i128::MIN).mul_wrapping(hundred))
8406+
Some(i256::from_i128(123456).mul_wrapping(hundred)),
8407+
Some(i256::from_i128(-123456).mul_wrapping(hundred))
84038408
]
84048409
);
84058410
}
@@ -9827,4 +9832,76 @@ mod tests {
98279832
"Cast non-nullable to non-nullable struct field returning null should fail",
98289833
);
98299834
}
9835+
9836+
#[test]
9837+
fn test_decimal_to_decimal_throw_error_on_precision_overflow_same_scale() {
9838+
let array = vec![Some(123456789)];
9839+
let array = create_decimal_array(array, 24, 2).unwrap();
9840+
println!("{:?}", array);
9841+
let input_type = DataType::Decimal128(24, 2);
9842+
let output_type = DataType::Decimal128(6, 2);
9843+
assert!(can_cast_types(&input_type, &output_type));
9844+
9845+
let options = CastOptions {
9846+
safe: false,
9847+
..Default::default()
9848+
};
9849+
let result = cast_with_options(&array, &output_type, &options);
9850+
assert_eq!(result.unwrap_err().to_string(),
9851+
"Invalid argument error: 123456790 is too large to store in a Decimal128 of precision 6. Max is 999999");
9852+
}
9853+
9854+
#[test]
9855+
fn test_decimal_to_decimal_throw_error_on_precision_overflow_lower_scale() {
9856+
let array = vec![Some(123456789)];
9857+
let array = create_decimal_array(array, 24, 2).unwrap();
9858+
println!("{:?}", array);
9859+
let input_type = DataType::Decimal128(24, 4);
9860+
let output_type = DataType::Decimal128(6, 2);
9861+
assert!(can_cast_types(&input_type, &output_type));
9862+
9863+
let options = CastOptions {
9864+
safe: false,
9865+
..Default::default()
9866+
};
9867+
let result = cast_with_options(&array, &output_type, &options);
9868+
assert_eq!(result.unwrap_err().to_string(),
9869+
"Invalid argument error: 123456790 is too large to store in a Decimal128 of precision 6. Max is 999999");
9870+
}
9871+
9872+
#[test]
9873+
fn test_decimal_to_decimal_throw_error_on_precision_overflow_greater_scale() {
9874+
let array = vec![Some(123456789)];
9875+
let array = create_decimal_array(array, 24, 2).unwrap();
9876+
println!("{:?}", array);
9877+
let input_type = DataType::Decimal128(24, 2);
9878+
let output_type = DataType::Decimal128(6, 3);
9879+
assert!(can_cast_types(&input_type, &output_type));
9880+
9881+
let options = CastOptions {
9882+
safe: false,
9883+
..Default::default()
9884+
};
9885+
let result = cast_with_options(&array, &output_type, &options);
9886+
assert_eq!(result.unwrap_err().to_string(),
9887+
"Invalid argument error: 1234567890 is too large to store in a Decimal128 of precision 6. Max is 999999");
9888+
}
9889+
9890+
#[test]
9891+
fn test_decimal_to_decimal_throw_error_on_precision_overflow_diff_type() {
9892+
let array = vec![Some(123456789)];
9893+
let array = create_decimal_array(array, 24, 2).unwrap();
9894+
println!("{:?}", array);
9895+
let input_type = DataType::Decimal128(24, 2);
9896+
let output_type = DataType::Decimal256(6, 2);
9897+
assert!(can_cast_types(&input_type, &output_type));
9898+
9899+
let options = CastOptions {
9900+
safe: false,
9901+
..Default::default()
9902+
};
9903+
let result = cast_with_options(&array, &output_type, &options);
9904+
assert_eq!(result.unwrap_err().to_string(),
9905+
"Invalid argument error: 123456789 is too large to store in a Decimal256 of precision 6. Max is 999999");
9906+
}
98309907
}

0 commit comments

Comments
 (0)